Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
607d6a91
"...ootdiffusion_pytorch.git" did not exist on "c50c08d9353cbdfe8c6a11c97935571b303c0d1d"
Unverified
Commit
607d6a91
authored
Sep 22, 2021
by
Ningxin Zheng
Committed by
GitHub
Sep 22, 2021
Browse files
Error code for Speedup Module (#4173)
parent
8b61e774
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
27 deletions
+79
-27
nni/compression/pytorch/speedup/compress_modules.py
nni/compression/pytorch/speedup/compress_modules.py
+47
-27
nni/compression/pytorch/speedup/compressor.py
nni/compression/pytorch/speedup/compressor.py
+1
-0
nni/compression/pytorch/speedup/error_code.py
nni/compression/pytorch/speedup/error_code.py
+31
-0
No files found.
nni/compression/pytorch/speedup/compress_modules.py
View file @
607d6a91
...
...
@@ -4,6 +4,7 @@
import
logging
import
torch
import
torch.nn
as
nn
from
.error_code
import
EmptyLayerError
,
ShapeMisMatchError
,
InputsNumberError
,
OutputTypeError
,
UnBalancedGroupError
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -44,7 +45,6 @@ replace_module = {
}
def
convert_to_coarse_mask
(
t_mask
,
dim
):
"""
Convert the mask tensor to the coarse-grained mask tensor.
...
...
@@ -87,6 +87,7 @@ def no_replace(module, masks):
_logger
.
debug
(
"no need to replace"
)
return
module
def
replace_prelu
(
prelu
,
masks
):
"""
Parameters
...
...
@@ -102,8 +103,11 @@ def replace_prelu(prelu, masks):
The new prelu module
"""
in_masks
,
output_mask
,
weight_mask
=
masks
assert
len
(
in_masks
)
==
1
assert
isinstance
(
output_mask
,
torch
.
Tensor
)
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
if
not
isinstance
(
output_mask
,
torch
.
Tensor
):
raise
OutputTypeError
(
type
(
output_mask
),
torch
.
Tensor
)
in_mask
=
in_masks
[
0
]
weight_mask
=
weight_mask
[
'weight'
]
pruned_in
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
...
...
@@ -112,13 +116,17 @@ def replace_prelu(prelu, masks):
n_remained_out
=
weight_mask
.
size
(
0
)
-
pruned_out
.
size
(
0
)
remained_in
,
remained_out
=
remained_in
.
to
(
prelu
.
weight
.
device
),
remained_out
.
to
(
prelu
.
weight
.
device
)
assert
n_remained_in
==
n_remained_out
if
n_remained_in
!=
n_remained_out
:
raise
ShapeMisMatchError
()
if
n_remained_in
==
0
:
return
torch
.
nn
.
Identity
()
new_prelu
=
torch
.
nn
.
PReLU
(
n_remained_in
)
new_prelu
.
weight
.
data
=
torch
.
index_select
(
prelu
.
weight
.
data
,
0
,
remained_in
)
new_prelu
.
weight
.
data
=
torch
.
index_select
(
prelu
.
weight
.
data
,
0
,
remained_in
)
return
new_prelu
def
replace_linear
(
linear
,
masks
):
"""
This function will replace the original linear according to
...
...
@@ -142,8 +150,11 @@ def replace_linear(linear, masks):
"""
in_masks
,
output_mask
,
weight_mask
=
masks
assert
isinstance
(
linear
,
nn
.
Linear
)
assert
len
(
in_masks
)
==
1
assert
isinstance
(
output_mask
,
torch
.
Tensor
)
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
if
not
isinstance
(
output_mask
,
torch
.
Tensor
):
raise
OutputTypeError
(
type
(
output_mask
),
torch
.
Tensor
)
in_mask
=
in_masks
[
0
]
weight_mask
=
weight_mask
[
'weight'
]
...
...
@@ -199,7 +210,8 @@ def replace_batchnorm1d(norm, masks):
# N, C, H, W
_
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
_
,
remained_out
=
convert_to_coarse_mask
(
output_mask
,
1
)
assert
remained_in
.
size
(
0
)
==
remained_out
.
size
(
0
)
if
remained_in
.
size
(
0
)
!=
remained_out
.
size
(
0
):
raise
ShapeMisMatchError
()
num_features
=
remained_in
.
size
(
0
)
_logger
.
info
(
"replace batchnorm1d with num_features: %d"
,
num_features
)
...
...
@@ -241,7 +253,8 @@ def replace_batchnorm2d(norm, masks):
# N, C, H, W
_
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
_
,
remained_out
=
convert_to_coarse_mask
(
output_mask
,
1
)
assert
remained_in
.
size
(
0
)
==
remained_out
.
size
(
0
)
if
remained_in
.
size
(
0
)
!=
remained_out
.
size
(
0
):
raise
ShapeMisMatchError
()
num_features
=
remained_in
.
size
(
0
)
_logger
.
info
(
"replace batchnorm2d with num_features: %d"
,
num_features
)
...
...
@@ -261,7 +274,6 @@ def replace_batchnorm2d(norm, masks):
return
new_norm
def
replace_conv2d
(
conv
,
masks
):
"""
Replace the original conv with a new one according to the infered
...
...
@@ -285,7 +297,8 @@ def replace_conv2d(conv, masks):
in_masks
,
output_mask
,
weight_masks
=
masks
assert
isinstance
(
conv
,
nn
.
Conv2d
)
# the conv layer should only have one input tensor
assert
len
(
in_masks
)
==
1
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
in_mask
=
in_masks
[
0
]
...
...
@@ -296,8 +309,8 @@ def replace_conv2d(conv, masks):
n_remained_in
=
weight_mask
.
size
(
1
)
*
conv
.
groups
-
pruned_in
.
size
(
0
)
n_remained_out
=
weight_mask
.
size
(
0
)
-
pruned_out
.
size
(
0
)
assert
n_remained_
in
=
=
remained_
in
.
size
(
0
)
assert
n_remained_out
==
remained_out
.
size
(
0
)
if
n_remained_in
!=
remained_in
.
size
(
0
)
or
n_remained_
out
!
=
remained_
out
.
size
(
0
)
:
raise
ShapeMisMatchError
(
)
k_size1
,
k_size2
=
conv
.
kernel_size
# Note: We should resolve the group dependency of the conv layers before
...
...
@@ -331,9 +344,10 @@ def replace_conv2d(conv, masks):
tmp_weight
=
torch
.
ones
(
n_remained_out
,
new_inchannel_step
,
k_size1
,
k_size2
)
tmp_weight
=
tmp_weight
.
to
(
conv
.
weight
.
device
)
assert
n_remained_in
%
new_inchannel_step
==
0
assert
n_remained_out
%
new_outchannel_step
==
0
if
new_inchannel_step
==
0
or
new_outchannel_step
==
0
:
raise
EmptyLayerError
()
if
n_remained_in
%
new_inchannel_step
!=
0
or
n_remained_out
%
new_outchannel_step
!=
0
:
raise
UnBalancedGroupError
()
new_groups
=
0
for
groupid
in
range
(
conv
.
groups
):
...
...
@@ -352,8 +366,9 @@ def replace_conv2d(conv, masks):
assert
len
(
current_output_index
)
==
0
continue
# check if the number of remained channel of each group are the same
assert
len
(
current_input_index
)
==
new_inchannel_step
assert
len
(
current_output_index
)
==
new_outchannel_step
if
len
(
current_input_index
)
!=
new_inchannel_step
or
len
(
current_output_index
)
!=
new_outchannel_step
:
raise
UnBalancedGroupError
()
# copy the weight into tmp_weight
new_out_start
=
new_outchannel_step
*
new_groups
new_out_end
=
new_out_start
+
new_outchannel_step
...
...
@@ -386,7 +401,6 @@ def replace_conv2d(conv, masks):
new_conv
.
bias
.
data
.
copy_
(
torch
.
index_select
(
conv
.
bias
.
data
,
0
,
remained_out
))
return
new_conv
...
...
@@ -410,7 +424,8 @@ def replace_convtranspose2d(convtrans, masks):
"""
in_masks
,
output_mask
,
weight_masks
=
masks
assert
isinstance
(
convtrans
,
torch
.
nn
.
ConvTranspose2d
)
assert
len
(
in_masks
)
==
1
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
in_mask
=
in_masks
[
0
]
weight_mask
=
weight_masks
[
'weight'
]
...
...
@@ -420,8 +435,9 @@ def replace_convtranspose2d(convtrans, masks):
n_remained_in
=
weight_mask
.
size
(
0
)
-
pruned_in
.
size
(
0
)
n_remained_out
=
weight_mask
.
size
(
1
)
*
convtrans
.
groups
-
pruned_out
.
size
(
0
)
assert
n_remained_in
==
remained_in
.
size
(
0
)
assert
n_remained_out
==
remained_out
.
size
(
0
)
if
n_remained_in
!=
remained_in
.
size
(
0
)
or
n_remained_out
!=
remained_out
.
size
(
0
):
raise
ShapeMisMatchError
()
k_size1
,
k_size2
=
convtrans
.
kernel_size
# Note: we should resolve the group dependency of the convtrans layers before
# run into this function
...
...
@@ -448,8 +464,10 @@ def replace_convtranspose2d(convtrans, masks):
n_remained_in
,
new_outchannel_step
,
k_size1
,
k_size2
)
tmp_weight
=
tmp_weight
.
to
(
convtrans
.
weight
.
device
)
assert
n_remained_in
%
new_inchannel_step
==
0
assert
n_remained_out
%
new_outchannel_step
==
0
if
new_inchannel_step
==
0
or
new_outchannel_step
==
0
:
raise
EmptyLayerError
()
if
n_remained_in
%
new_inchannel_step
!=
0
or
n_remained_out
%
new_outchannel_step
!=
0
:
raise
UnBalancedGroupError
()
new_groups
=
0
for
groupid
in
range
(
convtrans
.
groups
):
...
...
@@ -471,8 +489,9 @@ def replace_convtranspose2d(convtrans, masks):
assert
len
(
current_output_index
)
==
0
continue
# check if the number of remained channel of each group are the same
assert
len
(
current_input_index
)
==
new_inchannel_step
assert
len
(
current_output_index
)
==
new_outchannel_step
if
len
(
current_input_index
)
!=
new_inchannel_step
or
len
(
current_output_index
)
!=
new_outchannel_step
:
raise
UnBalancedGroupError
()
# copy the weight into tmp_weight
new_in_start
=
new_inchannel_step
*
new_groups
new_in_end
=
new_in_start
+
new_inchannel_step
...
...
@@ -505,7 +524,8 @@ def replace_convtranspose2d(convtrans, masks):
def
replace_layernorm
(
layernorm
,
masks
):
in_masks
,
_
,
_
=
masks
assert
isinstance
(
layernorm
,
nn
.
LayerNorm
)
assert
len
(
in_masks
)
==
1
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
in_mask
=
in_masks
[
0
]
dim_n
=
len
(
in_mask
.
size
())
new_shape
=
[]
...
...
nni/compression/pytorch/speedup/compressor.py
View file @
607d6a91
...
...
@@ -15,6 +15,7 @@ from .infer_mask import AutoMaskInference
from
.jit_translate
import
jit_to_python_function
from
..utils
import
rand_like_with_shape
_logger
=
logging
.
getLogger
(
__name__
)
_logger
.
setLevel
(
logging
.
INFO
)
...
...
nni/compression/pytorch/speedup/error_code.py
0 → 100644
View file @
607d6a91
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Error Code of the speedup
class
SpeedupError
(
Exception
):
def
__init__
(
self
,
msg
):
self
.
msg
=
msg
def
__str__
(
self
):
return
str
(
self
.
msg
)
class
EmptyLayerError
(
SpeedupError
):
def
__init__
(
self
):
super
(
EmptyLayerError
,
self
).
__init__
(
"Pruning a Layer to empty is not legal"
)
class
ShapeMisMatchError
(
SpeedupError
):
def
__init__
(
self
):
super
(
ShapeMisMatchError
,
self
).
__init__
(
"Shape mismatch!"
)
class
InputsNumberError
(
SpeedupError
):
def
__init__
(
self
):
super
(
InputsNumberError
,
self
).
__init__
(
"The number of the inputs of the target OP is wrong"
)
class
OutputTypeError
(
SpeedupError
):
def
__init__
(
self
,
current_type
,
target_type
):
msg
=
f
"The output type should be
{
str
(
target_type
)
}
, but
{
str
(
current_type
)
}
founded"
super
(
OutputTypeError
,
self
).
__init__
(
msg
)
class
UnBalancedGroupError
(
SpeedupError
):
def
__init__
(
self
):
msg
=
"The number remained filters in each group is different"
super
(
UnBalancedGroupError
,
self
).
__init__
(
msg
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment