Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bb397480
Unverified
Commit
bb397480
authored
Jul 20, 2021
by
kalineid
Committed by
GitHub
Jul 20, 2021
Browse files
Add tests for GraphConverterWithShape (#3951)
parent
403195f0
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
154 additions
and
40 deletions
+154
-40
test/ut/retiarii/convert_mixin.py
test/ut/retiarii/convert_mixin.py
+19
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+7
-4
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+14
-10
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+7
-4
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+20
-17
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+8
-5
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+79
-0
No files found.
test/ut/retiarii/convert_mixin.py
0 → 100644
View file @
bb397480
import
torch
from
nni.retiarii.converter.graph_gen
import
convert_to_graph
,
GraphConverterWithShape
class
ConvertMixin
:
@
staticmethod
def
_convert_model
(
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
return
model_ir
class
ConvertWithShapeMixin
:
@
staticmethod
def
_convert_model
(
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
,
converter
=
GraphConverterWithShape
(),
example_inputs
=
input
)
return
model_ir
test/ut/retiarii/test_convert.py
View file @
bb397480
...
@@ -13,9 +13,10 @@ import torchvision
...
@@ -13,9 +13,10 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
MnistNet
(
nn
.
Module
):
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
MnistNet
,
self
).
__init__
()
super
(
MnistNet
,
self
).
__init__
()
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
):
def
checkExportImport
(
self
,
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
exec_vars
=
{}
exec_vars
=
{}
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
model
,
(
x
,))
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
finally
:
remove_inject_pytorch_nn
()
remove_inject_pytorch_nn
()
class
TestConvertWithShape
(
TestConvert
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_basic.py
View file @
bb397480
...
@@ -9,12 +9,13 @@ import torchvision
...
@@ -9,12 +9,13 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -280,3 +280,7 @@ class TestConvert(unittest.TestCase):
...
@@ -280,3 +280,7 @@ class TestConvert(unittest.TestCase):
out1
=
x
.
ceil
()
out1
=
x
.
ceil
()
return
out1
return
out1
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
class
TestConvertWithShape
(
TestConvert
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_models.py
View file @
bb397480
...
@@ -10,11 +10,12 @@ import torchvision
...
@@ -10,11 +10,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
):
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return
result
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model
=
Net
(
4
)
model
=
Net
(
4
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
([
x
],
))
self
.
run_test
(
model
,
([
x
],
))
class
TestModelsWithShape
(
TestModels
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_operators.py
View file @
bb397480
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestOperators
(
unittest
.
TestCase
):
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
#print(model_code)
#print(model_code)
...
@@ -1386,3 +1386,6 @@ class TestOperators(unittest.TestCase):
...
@@ -1386,3 +1386,6 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
class
TestOperatorsWithShape
(
TestOperators
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_pytorch.py
View file @
bb397480
...
@@ -15,11 +15,12 @@ import torchvision
...
@@ -15,11 +15,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
):
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return
result
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -1231,3 +1231,6 @@ class TestPytorch(unittest.TestCase):
...
@@ -1231,3 +1231,6 @@ class TestPytorch(unittest.TestCase):
x
=
torch
.
randn
(
5
,
3
,
2
)
x
=
torch
.
randn
(
5
,
3
,
2
)
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
class
TestPytorchWithShape
(
TestPytorch
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_shape.py
0 → 100644
View file @
bb397480
import
unittest
import
torch
import
nni.retiarii.nn.pytorch
as
nn
from
.convert_mixin
import
ConvertWithShapeMixin
class
TestShape
(
unittest
.
TestCase
,
ConvertWithShapeMixin
):
def
test_simple_convnet
(
self
):
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
1
,
3
)
self
.
relu
=
nn
.
ReLU
()
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
relu
(
self
.
conv
(
x
)))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
self
.
assertEqual
(
conv_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
def
test_nested_module
(
self
):
class
ConvRelu
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
1
,
3
)
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
return
self
.
relu
(
self
.
conv
(
x
))
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
ConvRelu
()
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
conv
(
x
))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
# check if shape propagation works
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
self
.
assertEqual
(
cell_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
def
test_layerchoice
(
self
):
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
1
,
3
),
nn
.
Conv2d
(
3
,
1
,
5
,
padding
=
1
),
])
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
conv
(
x
))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
# check shape info of each candidates
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment