Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f1b8cd2c
"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a4c028ce0cc7e2947bfa75e79b3a1971d3f25e33"
Unverified
Commit
f1b8cd2c
authored
Sep 24, 2020
by
Ningxin Zheng
Committed by
GitHub
Sep 24, 2020
Browse files
Add the support for aten::mul operator. (#2905)
parent
986d58c1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
1 deletion
+8
-1
src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
...k/pynni/nni/compression/torch/speedup/compress_modules.py
+1
-0
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
+6
-0
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
+1
-1
No files found.
src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
View file @
f1b8cd2c
...
...
@@ -15,6 +15,7 @@ replace_module = {
'AdaptiveAvgPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU6'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Sigmoid'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Linear'
:
lambda
module
,
mask
:
replace_linear
(
module
,
mask
),
'Dropout'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Dropout2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
...
...
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
View file @
f1b8cd2c
...
...
@@ -221,12 +221,14 @@ Infer output and weight shape of a module/function from its input shape
infer_from_inshape
=
{
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Sigmoid'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::sigmoid'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
...
...
@@ -243,6 +245,10 @@ infer_from_inshape = {
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_inshape
(
module_masks
,
mask
),
'aten::add_'
:
lambda
module_masks
,
mask
:
add_inshape
(
module_masks
,
mask
),
'aten::add'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
# mul has the similar behaviour with add, they both request
# the input tesors to have the same shape
'aten::mul'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::mul_'
:
lambda
module_mask
,
mask
:
add_inshape
(
module_mask
,
mask
),
'aten::cat'
:
lambda
module_mask
,
mask
,
cat_info
,
last_visited
:
cat_inshape
(
module_mask
,
mask
,
cat_info
,
last_visited
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
...
...
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
View file @
f1b8cd2c
...
...
@@ -284,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
ori_channels
=
w_shape
[
0
]
for
i
in
channel_remain
:
mask
[
'weight'
][
i
]
=
torch
.
ones
(
w_shape
[
1
:])
if
hasattr
(
mask
,
'bias'
)
:
if
'bias'
in
mask
and
mask
[
'bias'
]
is
not
None
:
mask
[
'bias'
][
i
]
=
1
_logger
.
info
(
','
.
join
(
dset
))
_logger
.
info
(
'Pruned Filters after fixing conflict:'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment