Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
63f313bf
Unverified
Commit
63f313bf
authored
Oct 25, 2021
by
Yuge Zhang
Committed by
GitHub
Oct 25, 2021
Browse files
[Retiarii] Fix mul match in operation def (#4262)
parent
0efabe96
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
0 deletions
+22
-0
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+7
-0
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+15
-0
No files found.
nni/retiarii/operation_def/torch_op_def.py
View file @
63f313bf
...
@@ -254,6 +254,13 @@ class AtenFloordiv(PyTorchOperation):
...
@@ -254,6 +254,13 @@ class AtenFloordiv(PyTorchOperation):
return
f
'
{
output
}
=
{
inputs
[
0
]
}
//
{
inputs
[
1
]
}
'
return
f
'
{
output
}
=
{
inputs
[
0
]
}
//
{
inputs
[
1
]
}
'
class
AtenMul
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::mul'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
=
{
inputs
[
0
]
}
*
{
inputs
[
1
]
}
'
class
AtenLen
(
PyTorchOperation
):
class
AtenLen
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::len'
]
_ori_type_name
=
[
'aten::len'
]
...
...
test/ut/retiarii/test_convert_models.py
View file @
63f313bf
...
@@ -73,6 +73,7 @@ class TestModels(unittest.TestCase, ConvertMixin):
...
@@ -73,6 +73,7 @@ class TestModels(unittest.TestCase, ConvertMixin):
def
test_append_input_tensor
(
self
):
def
test_append_input_tensor
(
self
):
from
typing
import
List
from
typing
import
List
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
num_nodes
):
def
__init__
(
self
,
num_nodes
):
super
().
__init__
()
super
().
__init__
()
...
@@ -80,6 +81,7 @@ class TestModels(unittest.TestCase, ConvertMixin):
...
@@ -80,6 +81,7 @@ class TestModels(unittest.TestCase, ConvertMixin):
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
for
_
in
range
(
num_nodes
):
for
_
in
range
(
num_nodes
):
self
.
ops
.
append
(
nn
.
Linear
(
16
,
16
))
self
.
ops
.
append
(
nn
.
Linear
(
16
,
16
))
def
forward
(
self
,
x
:
List
[
torch
.
Tensor
]):
def
forward
(
self
,
x
:
List
[
torch
.
Tensor
]):
state
=
x
state
=
x
for
ops
in
self
.
ops
:
for
ops
in
self
.
ops
:
...
@@ -90,6 +92,19 @@ class TestModels(unittest.TestCase, ConvertMixin):
...
@@ -90,6 +92,19 @@ class TestModels(unittest.TestCase, ConvertMixin):
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
([
x
],
))
self
.
run_test
(
model
,
([
x
],
))
def
test_channels_shuffle
(
self
):
class
Net
(
nn
.
Module
):
def
forward
(
self
,
x
):
bs
,
num_channels
,
height
,
width
=
x
.
size
()
x
=
x
.
reshape
(
bs
*
num_channels
//
2
,
2
,
height
*
width
)
x
=
x
.
permute
(
1
,
0
,
2
)
x
=
x
.
reshape
(
2
,
-
1
,
num_channels
//
2
,
height
,
width
)
return
x
[
0
],
x
[
1
]
model
=
Net
()
x
=
torch
.
rand
((
1
,
64
,
224
,
224
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
(
x
,
))
def
test_identity_node
(
self
):
def
test_identity_node
(
self
):
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment