Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f1b8cd2c
Unverified
Commit
f1b8cd2c
authored
Sep 24, 2020
by
Ningxin Zheng
Committed by
GitHub
Sep 24, 2020
Browse files
Add the support for aten::mul operator. (#2905)
parent
986d58c1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
1 deletion
+8
-1
src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
...k/pynni/nni/compression/torch/speedup/compress_modules.py
+1
-0
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
+6
-0
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
+1
-1
No files found.
src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
View file @
f1b8cd2c
...
@@ -15,6 +15,7 @@ replace_module = {
...
@@ -15,6 +15,7 @@ replace_module = {
'AdaptiveAvgPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'AdaptiveAvgPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU6'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU6'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Sigmoid'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Linear'
:
lambda
module
,
mask
:
replace_linear
(
module
,
mask
),
'Linear'
:
lambda
module
,
mask
:
replace_linear
(
module
,
mask
),
'Dropout'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Dropout'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Dropout2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Dropout2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
...
...
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
View file @
f1b8cd2c
...
@@ -221,12 +221,14 @@ Infer output and weight shape of a module/function from its input shape
...
@@ -221,12 +221,14 @@ Infer output and weight shape of a module/function from its input shape
infer_from_inshape
=
{
infer_from_inshape
=
{
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Sigmoid'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::sigmoid'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
...
@@ -243,6 +245,10 @@ infer_from_inshape = {
...
@@ -243,6 +245,10 @@ infer_from_inshape = {
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_inshape
(
module_masks
,
mask
),
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_inshape
(
module_masks
,
mask
),
'aten::add_'
:
lambda
module_masks
,
mask
:
add_inshape
(
module_masks
,
mask
),
'aten::add_'
:
lambda
module_masks
,
mask
:
add_inshape
(
module_masks
,
mask
),
'aten::add'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::add'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
# mul has the similar behaviour with add, they both request
# the input tesors to have the same shape
'aten::mul'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::mul_'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::cat'
:
lambda
module_mask
,
mask
,
cat_info
,
last_visited
:
cat_inshape
(
module_mask
,
mask
,
cat_info
,
last_visited
),
'aten::cat'
:
lambda
module_mask
,
mask
,
cat_info
,
last_visited
:
cat_inshape
(
module_mask
,
mask
,
cat_info
,
last_visited
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
...
...
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
View file @
f1b8cd2c
...
@@ -284,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
...
@@ -284,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
ori_channels
=
w_shape
[
0
]
ori_channels
=
w_shape
[
0
]
for
i
in
channel_remain
:
for
i
in
channel_remain
:
mask
[
'weight'
][
i
]
=
torch
.
ones
(
w_shape
[
1
:])
mask
[
'weight'
][
i
]
=
torch
.
ones
(
w_shape
[
1
:])
if
hasattr
(
mask
,
'bias'
)
:
if
'bias'
in
mask
and
mask
[
'bias'
]
is
not
None
:
mask
[
'bias'
][
i
]
=
1
mask
[
'bias'
][
i
]
=
1
_logger
.
info
(
','
.
join
(
dset
))
_logger
.
info
(
','
.
join
(
dset
))
_logger
.
info
(
'Pruned Filters after fixing conflict:'
)
_logger
.
info
(
'Pruned Filters after fixing conflict:'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment