Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6aaa2d92
Unverified
Commit
6aaa2d92
authored
Mar 28, 2022
by
J-shang
Committed by
GitHub
Mar 28, 2022
Browse files
[Compression] support aten::sub & aten::constant_pad_nd in speedup (#4644)
parent
45cefc7c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
1 deletion
+41
-1
nni/compression/pytorch/speedup/compress_modules.py
nni/compression/pytorch/speedup/compress_modules.py
+1
-0
nni/compression/pytorch/speedup/jit_translate.py
nni/compression/pytorch/speedup/jit_translate.py
+40
-1
No files found.
nni/compression/pytorch/speedup/compress_modules.py
View file @
6aaa2d92
...
...
@@ -16,6 +16,7 @@ replace_module = {
'MaxPool2d'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'AvgPool2d'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'AdaptiveAvgPool2d'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'ZeroPad2d'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'ReLU'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'ReLU6'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'LeakyReLU'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
...
...
nni/compression/pytorch/speedup/jit_translate.py
View file @
6aaa2d92
...
...
@@ -142,6 +142,29 @@ def add_python(node, speedup):
return
new_add
def
sub_python
(
node
,
speedup
):
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
constant
=
[
None
,
None
]
for
i
in
range
(
2
):
input_i
=
inputs
[
i
]
debug_name
=
input_i
.
debugName
()
if
debug_name
not
in
speedup
.
internal_result
:
# this input is a constant value
# TODO: what if this input is a constant tensor
if
input_i
.
toIValue
()
is
not
None
:
constant
[
i
]
=
parse_constant
(
input_i
,
speedup
)
break
if
constant
[
0
]
is
None
and
constant
[
1
]
is
None
:
new_sub
=
torch
.
sub
elif
constant
[
0
]
is
not
None
:
new_sub
=
partial
(
torch
.
sub
,
input
=
constant
)
else
:
new_sub
=
partial
(
torch
.
sub
,
other
=
constant
)
return
new_sub
def
floor_div_python
(
node
,
speedup
):
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
...
...
@@ -228,6 +251,10 @@ def gelu_python(node, speedup):
return
torch
.
nn
.
GELU
()
def
silu_python
(
node
,
speedup
):
return
torch
.
nn
.
SiLU
()
def
avgpool2d_python
(
node
,
speedup
):
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
...
...
@@ -277,6 +304,14 @@ def unsqueeze_python(node, speedup):
new_unsqueeze
=
partial
(
torch
.
unsqueeze
,
dim
=
dim
)
return
new_unsqueeze
def
constant_pad_nd_python
(
node
,
speedup
):
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
pad
=
translate_list
(
inputs
[
1
],
speedup
)
value
=
parse_constant
(
inputs
[
2
],
speedup
)
new_constant_pad_nd
=
partial
(
torch
.
nn
.
functional
.
pad
,
pad
=
pad
,
value
=
value
)
return
new_constant_pad_nd
##########################################################
# Split Line
# Following module/functions cannot be translated into a
...
...
@@ -379,7 +414,7 @@ def reshape_python(node, speedup):
logger
.
info
(
'Reshape Module output size: %s'
,
str
(
self
.
shape
))
def
forward
(
self
,
*
args
):
return
args
[
0
].
view
(
self
.
shape
)
return
args
[
0
].
reshape
(
self
.
shape
)
c_node
=
node
.
key_node
inputs
=
list
(
c_node
.
inputs
())
shape
=
translate_list
(
inputs
[
1
],
speedup
)
...
...
@@ -505,6 +540,8 @@ def cat_python(node, speedup):
trans_from_jit_to_python
=
{
'aten::add'
:
add_python
,
'aten::add_'
:
add_python
,
'aten::sub'
:
sub_python
,
'aten::sub_'
:
sub_python
,
'aten::mul'
:
mul_python
,
'aten::mul_'
:
mul_python
,
'aten::relu'
:
relu_python
,
...
...
@@ -542,6 +579,8 @@ trans_from_jit_to_python = {
'aten::exp'
:
exp_python
,
'aten::squeeze'
:
squeeze_python
,
'aten::unsqueeze'
:
unsqueeze_python
,
'aten::constant_pad_nd'
:
constant_pad_nd_python
,
'aten::silu'
:
silu_python
,
'prim::TupleUnpack'
:
tupleunpack_python
,
'prim::ListUnpack'
:
tupleunpack_python
,
'prim::NumToTensor'
:
num2tensor_python
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment