Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bb397480
"docs/en_US/TrialExample/OpEvoExamples.rst" did not exist on "dbb2434f5d2d976be26b594342a68cb46619ecea"
Unverified
Commit
bb397480
authored
Jul 20, 2021
by
kalineid
Committed by
GitHub
Jul 20, 2021
Browse files
Add tests for GraphConverterWithShape (#3951)
parent
403195f0
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
154 additions
and
40 deletions
+154
-40
test/ut/retiarii/convert_mixin.py
test/ut/retiarii/convert_mixin.py
+19
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+7
-4
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+14
-10
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+7
-4
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+20
-17
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+8
-5
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+79
-0
No files found.
test/ut/retiarii/convert_mixin.py
0 → 100644
View file @
bb397480
import
torch
from
nni.retiarii.converter.graph_gen
import
convert_to_graph
,
GraphConverterWithShape
class
ConvertMixin
:
@
staticmethod
def
_convert_model
(
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
return
model_ir
class
ConvertWithShapeMixin
:
@
staticmethod
def
_convert_model
(
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
,
converter
=
GraphConverterWithShape
(),
example_inputs
=
input
)
return
model_ir
test/ut/retiarii/test_convert.py
View file @
bb397480
...
@@ -13,9 +13,10 @@ import torchvision
...
@@ -13,9 +13,10 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
MnistNet
(
nn
.
Module
):
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
MnistNet
,
self
).
__init__
()
super
(
MnistNet
,
self
).
__init__
()
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
):
def
checkExportImport
(
self
,
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
exec_vars
=
{}
exec_vars
=
{}
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
model
,
(
x
,))
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
finally
:
remove_inject_pytorch_nn
()
remove_inject_pytorch_nn
()
class
TestConvertWithShape
(
TestConvert
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_basic.py
View file @
bb397480
...
@@ -9,12 +9,13 @@ import torchvision
...
@@ -9,12 +9,13 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -280,3 +280,7 @@ class TestConvert(unittest.TestCase):
...
@@ -280,3 +280,7 @@ class TestConvert(unittest.TestCase):
out1
=
x
.
ceil
()
out1
=
x
.
ceil
()
return
out1
return
out1
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
class
TestConvertWithShape
(
TestConvert
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_models.py
View file @
bb397480
...
@@ -10,11 +10,12 @@ import torchvision
...
@@ -10,11 +10,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
):
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return
result
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model
=
Net
(
4
)
model
=
Net
(
4
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
([
x
],
))
self
.
run_test
(
model
,
([
x
],
))
class
TestModelsWithShape
(
TestModels
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_operators.py
View file @
bb397480
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestOperators
(
unittest
.
TestCase
):
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
#print(model_code)
#print(model_code)
...
@@ -1386,3 +1386,6 @@ class TestOperators(unittest.TestCase):
...
@@ -1386,3 +1386,6 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
class
TestOperatorsWithShape
(
TestOperators
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_pytorch.py
View file @
bb397480
...
@@ -15,11 +15,12 @@ import torchvision
...
@@ -15,11 +15,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
):
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return
result
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -1231,3 +1231,6 @@ class TestPytorch(unittest.TestCase):
...
@@ -1231,3 +1231,6 @@ class TestPytorch(unittest.TestCase):
x
=
torch
.
randn
(
5
,
3
,
2
)
x
=
torch
.
randn
(
5
,
3
,
2
)
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
class
TestPytorchWithShape
(
TestPytorch
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_shape.py
0 → 100644
View file @
bb397480
import
unittest
import
torch
import
nni.retiarii.nn.pytorch
as
nn
from
.convert_mixin
import
ConvertWithShapeMixin
class
TestShape
(
unittest
.
TestCase
,
ConvertWithShapeMixin
):
def
test_simple_convnet
(
self
):
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
1
,
3
)
self
.
relu
=
nn
.
ReLU
()
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
relu
(
self
.
conv
(
x
)))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
self
.
assertEqual
(
conv_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
def
test_nested_module
(
self
):
class
ConvRelu
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
1
,
3
)
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
return
self
.
relu
(
self
.
conv
(
x
))
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
ConvRelu
()
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
conv
(
x
))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
# check if shape propagation works
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
self
.
assertEqual
(
cell_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
def
test_layerchoice
(
self
):
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
1
,
3
),
nn
.
Conv2d
(
3
,
1
,
5
,
padding
=
1
),
])
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
conv
(
x
))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
# check shape info of each candidates
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment