Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
aa1f71c8
"src/include/ConstantMatrixDescriptor.hpp" did not exist on "f35c64eb78af4754e78f8746c8e28d2ac8b68e80"
Unverified
Commit
aa1f71c8
authored
May 20, 2022
by
J-shang
Committed by
GitHub
May 20, 2022
Browse files
[Compression] support expand_as (#4852)
parent
39ec21ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
3 deletions
+10
-3
nni/common/graph_utils.py
nni/common/graph_utils.py
+1
-1
nni/compression/pytorch/speedup/jit_translate.py
nni/compression/pytorch/speedup/jit_translate.py
+8
-1
nni/compression/pytorch/utils/shape_dependency.py
nni/compression/pytorch/utils/shape_dependency.py
+1
-1
No files found.
nni/common/graph_utils.py
View file @
aa1f71c8
...
@@ -775,7 +775,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -775,7 +775,7 @@ class TorchModuleGraph(TorchGraph):
"""
"""
# extract the input & output shape for the view and flatten
# extract the input & output shape for the view and flatten
for
node_group
in
self
.
nodes_py
.
nodes_op
:
for
node_group
in
self
.
nodes_py
.
nodes_op
:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
,
'aten::mean'
,
'aten::reshape'
]:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
,
'aten::mean'
,
'aten::reshape'
,
'aten::expand_as'
]:
# get shape infor for view (aten::view) func
# get shape infor for view (aten::view) func
cpp_node
=
list
(
filter
(
lambda
x
:
x
.
kind
()
==
node_group
.
op_type
,
cpp_node
=
list
(
filter
(
lambda
x
:
x
.
kind
()
==
node_group
.
op_type
,
node_group
.
node_cpps
))[
0
]
node_group
.
node_cpps
))[
0
]
...
...
nni/compression/pytorch/speedup/jit_translate.py
View file @
aa1f71c8
...
@@ -537,6 +537,13 @@ def cat_python(node, speedup):
...
@@ -537,6 +537,13 @@ def cat_python(node, speedup):
return
CatModule
(
dim
)
return
CatModule
(
dim
)
def
expandas_python
(
node
,
speedup
):
class
ExpandasModule
(
torch
.
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
return
x
.
expand_as
(
y
).
clone
()
return
ExpandasModule
()
trans_from_jit_to_python
=
{
trans_from_jit_to_python
=
{
'aten::add'
:
add_python
,
'aten::add'
:
add_python
,
'aten::add_'
:
add_python
,
'aten::add_'
:
add_python
,
...
@@ -581,11 +588,11 @@ trans_from_jit_to_python = {
...
@@ -581,11 +588,11 @@ trans_from_jit_to_python = {
'aten::unsqueeze'
:
unsqueeze_python
,
'aten::unsqueeze'
:
unsqueeze_python
,
'aten::constant_pad_nd'
:
constant_pad_nd_python
,
'aten::constant_pad_nd'
:
constant_pad_nd_python
,
'aten::silu'
:
silu_python
,
'aten::silu'
:
silu_python
,
'aten::expand_as'
:
expandas_python
,
'prim::TupleUnpack'
:
tupleunpack_python
,
'prim::TupleUnpack'
:
tupleunpack_python
,
'prim::ListUnpack'
:
tupleunpack_python
,
'prim::ListUnpack'
:
tupleunpack_python
,
'prim::NumToTensor'
:
num2tensor_python
,
'prim::NumToTensor'
:
num2tensor_python
,
'prim::GetAttr'
:
getattr_python
'prim::GetAttr'
:
getattr_python
}
}
...
...
nni/compression/pytorch/utils/shape_dependency.py
View file @
aa1f71c8
...
@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
...
@@ -20,7 +20,7 @@ MUL_TYPES = ['aten::mul', 'atem::mul_']
CAT_TYPE
=
'aten::cat'
CAT_TYPE
=
'aten::cat'
logger
=
logging
.
getLogger
(
'Shape_Dependency'
)
logger
=
logging
.
getLogger
(
'Shape_Dependency'
)
RESHAPE_OPS
=
[
CAT_TYPE
,
'aten::view'
,
RESHAPE_OPS
=
[
CAT_TYPE
,
'aten::view'
,
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
]
'aten::reshape'
,
'aten::flatten'
,
'aten::mean'
,
'aten::expand_as'
]
def
lcm_list
(
L
):
def
lcm_list
(
L
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment