OpenDAS / nni · Commits

Unverified commit 607d6a91, authored Sep 22, 2021 by Ningxin Zheng, committed by GitHub on Sep 22, 2021

Error code for Speedup Module (#4173)
Parent: 8b61e774

Showing 3 changed files with 79 additions and 27 deletions:
nni/compression/pytorch/speedup/compress_modules.py  (+47, -27)
nni/compression/pytorch/speedup/compressor.py        (+1, -0)
nni/compression/pytorch/speedup/error_code.py        (+31, -0)
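The change is a single pattern applied across the module-replacement helpers: each bare assert that validated the inferred masks becomes an explicit check raising a typed SpeedupError subclass, so a failed speedup reports what went wrong instead of a bare AssertionError. A minimal sketch of the pattern, assuming the import path implied by the new file below (validate_masks itself is a hypothetical helper, not part of the commit):

    # Sketch of the assert-to-exception pattern this commit applies throughout.
    # `validate_masks` is a hypothetical helper, not committed code.
    import torch

    from nni.compression.pytorch.speedup.error_code import (
        InputsNumberError, OutputTypeError)

    def validate_masks(in_masks, output_mask):
        # before: assert len(in_masks) == 1
        if len(in_masks) != 1:
            raise InputsNumberError()
        # before: assert isinstance(output_mask, torch.Tensor)
        if not isinstance(output_mask, torch.Tensor):
            raise OutputTypeError(type(output_mask), torch.Tensor)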
nni/compression/pytorch/speedup/compress_modules.py
@@ -4,6 +4,7 @@
 import logging
 import torch
 import torch.nn as nn
+from .error_code import EmptyLayerError, ShapeMisMatchError, InputsNumberError, OutputTypeError, UnBalancedGroupError

 _logger = logging.getLogger(__name__)
@@ -44,7 +45,6 @@ replace_module = {
 }

-
 def convert_to_coarse_mask(t_mask, dim):
     """
     Convert the mask tensor to the coarse-grained mask tensor.
@@ -87,6 +87,7 @@ def no_replace(module, masks):
     _logger.debug("no need to replace")
     return module

+
 def replace_prelu(prelu, masks):
     """
     Parameters
@@ -102,8 +103,11 @@ def replace_prelu(prelu, masks):
         The new prelu module
     """
     in_masks, output_mask, weight_mask = masks
-    assert len(in_masks) == 1
-    assert isinstance(output_mask, torch.Tensor)
+    if len(in_masks) != 1:
+        raise InputsNumberError()
+    if not isinstance(output_mask, torch.Tensor):
+        raise OutputTypeError(type(output_mask), torch.Tensor)
     in_mask = in_masks[0]
     weight_mask = weight_mask['weight']
     pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
@@ -112,13 +116,17 @@ def replace_prelu(prelu, masks):
     n_remained_out = weight_mask.size(0) - pruned_out.size(0)
     remained_in, remained_out = remained_in.to(
         prelu.weight.device), remained_out.to(prelu.weight.device)
-    assert n_remained_in == n_remained_out
+    if n_remained_in != n_remained_out:
+        raise ShapeMisMatchError()
     if n_remained_in == 0:
         return torch.nn.Identity()
     new_prelu = torch.nn.PReLU(n_remained_in)
     new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, remained_in)
     return new_prelu

 def replace_linear(linear, masks):
     """
     This function will replace the original linear according to
@@ -142,8 +150,11 @@ def replace_linear(linear, masks):
     """
     in_masks, output_mask, weight_mask = masks
     assert isinstance(linear, nn.Linear)
-    assert len(in_masks) == 1
-    assert isinstance(output_mask, torch.Tensor)
+    if len(in_masks) != 1:
+        raise InputsNumberError()
+    if not isinstance(output_mask, torch.Tensor):
+        raise OutputTypeError(type(output_mask), torch.Tensor)
     in_mask = in_masks[0]
     weight_mask = weight_mask['weight']
@@ -199,7 +210,8 @@ def replace_batchnorm1d(norm, masks):
     # N, C, H, W
     _, remained_in = convert_to_coarse_mask(in_mask, 1)
     _, remained_out = convert_to_coarse_mask(output_mask, 1)
-    assert remained_in.size(0) == remained_out.size(0)
+    if remained_in.size(0) != remained_out.size(0):
+        raise ShapeMisMatchError()
     num_features = remained_in.size(0)
     _logger.info("replace batchnorm1d with num_features: %d", num_features)
@@ -241,7 +253,8 @@ def replace_batchnorm2d(norm, masks):
     # N, C, H, W
     _, remained_in = convert_to_coarse_mask(in_mask, 1)
     _, remained_out = convert_to_coarse_mask(output_mask, 1)
-    assert remained_in.size(0) == remained_out.size(0)
+    if remained_in.size(0) != remained_out.size(0):
+        raise ShapeMisMatchError()
     num_features = remained_in.size(0)
     _logger.info("replace batchnorm2d with num_features: %d", num_features)
@@ -261,7 +274,6 @@ def replace_batchnorm2d(norm, masks):
     return new_norm

-
 def replace_conv2d(conv, masks):
     """
     Replace the original conv with a new one according to the infered
@@ -285,7 +297,8 @@ def replace_conv2d(conv, masks):
     in_masks, output_mask, weight_masks = masks
     assert isinstance(conv, nn.Conv2d)
     # the conv layer should only have one input tensor
-    assert len(in_masks) == 1
+    if len(in_masks) != 1:
+        raise InputsNumberError()
     in_mask = in_masks[0]
@@ -296,8 +309,8 @@ def replace_conv2d(conv, masks):
     n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
     n_remained_out = weight_mask.size(0) - pruned_out.size(0)
-    assert n_remained_in == remained_in.size(0)
-    assert n_remained_out == remained_out.size(0)
+    if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
+        raise ShapeMisMatchError()
     k_size1, k_size2 = conv.kernel_size
     # Note: We should resolve the group dependency of the conv layers before
@@ -331,9 +344,10 @@ def replace_conv2d(conv, masks):
     tmp_weight = torch.ones(
         n_remained_out, new_inchannel_step, k_size1, k_size2)
     tmp_weight = tmp_weight.to(conv.weight.device)
-    assert n_remained_in % new_inchannel_step == 0
-    assert n_remained_out % new_outchannel_step == 0
+    if new_inchannel_step == 0 or new_outchannel_step == 0:
+        raise EmptyLayerError()
+    if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
+        raise UnBalancedGroupError()
     new_groups = 0
     for groupid in range(conv.groups):
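A note on the intent of the two checks just added (my reading of the hunk, not wording from the commit): new_inchannel_step and new_outchannel_step are the per-group counts of surviving channels, so a zero step means the layer would be rebuilt empty (EmptyLayerError), while a non-zero remainder means the survivors cannot be divided evenly across conv.groups (UnBalancedGroupError). With hypothetical numbers:

    # Hypothetical numbers for the group-balance rule checked above.
    groups = 4
    new_inchannel_step = 7   # surviving input channels expected per group
    n_remained_in = 30       # surviving input channels in total
    # a step of 0 would mean an empty group/layer -> EmptyLayerError
    print(n_remained_in % new_inchannel_step)  # 2, nonzero -> UnBalancedGroupError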
@@ -352,8 +366,9 @@ def replace_conv2d(conv, masks):
             assert len(current_output_index) == 0
             continue
         # check if the number of remained channel of each group are the same
-        assert len(current_input_index) == new_inchannel_step
-        assert len(current_output_index) == new_outchannel_step
+        if len(current_input_index) != new_inchannel_step or \
+                len(current_output_index) != new_outchannel_step:
+            raise UnBalancedGroupError()
         # copy the weight into tmp_weight
         new_out_start = new_outchannel_step * new_groups
         new_out_end = new_out_start + new_outchannel_step
@@ -386,7 +401,6 @@ def replace_conv2d(conv, masks):
         new_conv.bias.data.copy_(torch.index_select(
             conv.bias.data, 0, remained_out))
-
     return new_conv
@@ -410,7 +424,8 @@ def replace_convtranspose2d(convtrans, masks):
     """
     in_masks, output_mask, weight_masks = masks
     assert isinstance(convtrans, torch.nn.ConvTranspose2d)
-    assert len(in_masks) == 1
+    if len(in_masks) != 1:
+        raise InputsNumberError()
     in_mask = in_masks[0]
     weight_mask = weight_masks['weight']
@@ -420,8 +435,9 @@ def replace_convtranspose2d(convtrans, masks):
     n_remained_in = weight_mask.size(0) - pruned_in.size(0)
     n_remained_out = weight_mask.size(
         1) * convtrans.groups - pruned_out.size(0)
-    assert n_remained_in == remained_in.size(0)
-    assert n_remained_out == remained_out.size(0)
+    if n_remained_in != remained_in.size(0) or \
+            n_remained_out != remained_out.size(0):
+        raise ShapeMisMatchError()
     k_size1, k_size2 = convtrans.kernel_size
     # Note: we should resolve the group dependency of the convtrans layers before
     # run into this function
@@ -448,8 +464,10 @@ def replace_convtranspose2d(convtrans, masks):
         n_remained_in, new_outchannel_step, k_size1, k_size2)
     tmp_weight = tmp_weight.to(convtrans.weight.device)
-    assert n_remained_in % new_inchannel_step == 0
-    assert n_remained_out % new_outchannel_step == 0
+    if new_inchannel_step == 0 or new_outchannel_step == 0:
+        raise EmptyLayerError()
+    if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
+        raise UnBalancedGroupError()
     new_groups = 0
     for groupid in range(convtrans.groups):
@@ -471,8 +489,9 @@ def replace_convtranspose2d(convtrans, masks):
             assert len(current_output_index) == 0
             continue
         # check if the number of remained channel of each group are the same
-        assert len(current_input_index) == new_inchannel_step
-        assert len(current_output_index) == new_outchannel_step
+        if len(current_input_index) != new_inchannel_step or \
+                len(current_output_index) != new_outchannel_step:
+            raise UnBalancedGroupError()
         # copy the weight into tmp_weight
         new_in_start = new_inchannel_step * new_groups
         new_in_end = new_in_start + new_inchannel_step
@@ -505,7 +524,8 @@ def replace_convtranspose2d(convtrans, masks):
 def replace_layernorm(layernorm, masks):
     in_masks, _, _ = masks
     assert isinstance(layernorm, nn.LayerNorm)
-    assert len(in_masks) == 1
+    if len(in_masks) != 1:
+        raise InputsNumberError()
     in_mask = in_masks[0]
     dim_n = len(in_mask.size())
     new_shape = []
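For orientation, convert_to_coarse_mask (context lines above) is what produces the pruned/remained index pairs these checks compare: it collapses a fine-grained 0/1 mask to the indices along one dimension that are entirely pruned versus at least partly kept. A standalone sketch of that idea, assuming a 0/1 tensor mask (an illustration, not the committed implementation):

    # Illustrative reimplementation of the coarse-mask idea, not committed code.
    import torch

    def coarse_mask_sketch(t_mask, dim):
        # sum the mask over every dimension except `dim`
        other_dims = [d for d in range(t_mask.dim()) if d != dim]
        per_channel = t_mask.sum(dim=other_dims)
        pruned = torch.nonzero(per_channel == 0, as_tuple=False).squeeze(-1)
        remained = torch.nonzero(per_channel != 0, as_tuple=False).squeeze(-1)
        return pruned, remained

    mask = torch.ones(1, 4, 8, 8)
    mask[:, 1] = 0  # channel 1 fully masked out
    pruned, remained = coarse_mask_sketch(mask, 1)
    print(pruned.tolist(), remained.tolist())  # [1] [0, 2, 3]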
nni/compression/pytorch/speedup/compressor.py

@@ -15,6 +15,7 @@ from .infer_mask import AutoMaskInference
 from .jit_translate import jit_to_python_function
 from ..utils import rand_like_with_shape

 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
nni/compression/pytorch/speedup/error_code.py  (new file, mode 100644)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+# Error Code of the speedup
+
+class SpeedupError(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return str(self.msg)
+
+class EmptyLayerError(SpeedupError):
+    def __init__(self):
+        super(EmptyLayerError, self).__init__("Pruning a Layer to empty is not legal")
+
+class ShapeMisMatchError(SpeedupError):
+    def __init__(self):
+        super(ShapeMisMatchError, self).__init__("Shape mismatch!")
+
+class InputsNumberError(SpeedupError):
+    def __init__(self):
+        super(InputsNumberError, self).__init__("The number of the inputs of the target OP is wrong")
+
+class OutputTypeError(SpeedupError):
+    def __init__(self, current_type, target_type):
+        msg = f"The output type should be {str(target_type)}, but {str(current_type)} founded"
+        super(OutputTypeError, self).__init__(msg)
+
+class UnBalancedGroupError(SpeedupError):
+    def __init__(self):
+        msg = "The number remained filters in each group is different"
+        super(UnBalancedGroupError, self).__init__(msg)
\ No newline at end of file
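Because every class above derives from SpeedupError, callers can catch the whole family or react to a specific failure. A hedged sketch of how user code might use this, assuming a configured ModelSpeedup instance whose speedup_model() call lets these exceptions propagate (this commit adds no handlers of its own):

    # Sketch: reacting to the typed speedup failures (illustrative only).
    from nni.compression.pytorch.speedup.error_code import (
        SpeedupError, ShapeMisMatchError, UnBalancedGroupError)

    def run_speedup(speedup):  # `speedup` is a configured ModelSpeedup
        try:
            speedup.speedup_model()
        except ShapeMisMatchError:
            print("inferred input/output channel counts disagree; check the masks")
        except UnBalancedGroupError:
            print("pruning left conv groups with unequal channel counts")
        except SpeedupError as err:
            print(f"speedup failed: {err}")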