Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bb397480
Unverified
Commit
bb397480
authored
Jul 20, 2021
by
kalineid
Committed by
GitHub
Jul 20, 2021
Browse files
Add tests for GraphConverterWithShape (#3951)
parent
403195f0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
154 additions
and
40 deletions
+154
-40
test/ut/retiarii/convert_mixin.py
test/ut/retiarii/convert_mixin.py
+19
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+7
-4
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+14
-10
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+7
-4
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+20
-17
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+8
-5
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+79
-0
No files found.
test/ut/retiarii/convert_mixin.py
0 → 100644
View file @
bb397480
import
torch
from
nni.retiarii.converter.graph_gen
import
convert_to_graph
,
GraphConverterWithShape
class ConvertMixin:
    """Test mixin that converts a model to graph IR without shape inference."""

    @staticmethod
    def _convert_model(model, input):
        """Script *model* with TorchScript and convert it to NNI graph IR.

        ``input`` is accepted only for interface parity with
        ``ConvertWithShapeMixin`` and is intentionally unused here.
        """
        scripted = torch.jit.script(model)
        return convert_to_graph(scripted, model)
class ConvertWithShapeMixin:
    """Test mixin that converts a model to graph IR with shape inference."""

    @staticmethod
    def _convert_model(model, input):
        """Script *model* and convert it to NNI graph IR, recording shapes.

        ``input`` is forwarded as the example inputs that
        ``GraphConverterWithShape`` traces to annotate each node with
        ``input_shape`` / ``output_shape``.
        """
        scripted = torch.jit.script(model)
        return convert_to_graph(
            scripted,
            model,
            converter=GraphConverterWithShape(),
            example_inputs=input,
        )
test/ut/retiarii/test_convert.py
View file @
bb397480
...
...
@@ -13,9 +13,10 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MnistNet
,
self
).
__init__
()
...
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
...
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return
result
def
checkExportImport
(
self
,
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
exec_vars
=
{}
...
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
remove_inject_pytorch_nn
()
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
    """Re-run every TestConvert case using shape-aware conversion."""
test/ut/retiarii/test_convert_basic.py
View file @
bb397480
...
...
@@ -9,12 +9,13 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
...
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
...
...
@@ -188,7 +188,7 @@ class TestConvert(unittest.TestCase):
out2
=
torch
.
addmv
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
),
))
def
test_basic_addr
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
...
...
@@ -204,7 +204,7 @@ class TestConvert(unittest.TestCase):
out2
=
torch
.
allclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
10000.
,
1e-07
]),
torch
.
tensor
([
10000.1
,
1e-08
]),
))
def
test_basic_angle
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
...
...
@@ -229,7 +229,7 @@ class TestConvert(unittest.TestCase):
o4
=
x
.
argmin
(
dim
=
1
,
keepdim
=
True
)
return
out1
,
out2
,
out3
,
out4
,
out5
,
o1
,
o2
,
o3
,
o4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
def
test_basic_argsort
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
...
...
@@ -241,7 +241,7 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
# skip backward(gradient=None, retain_graph=None, create_graph=False)
def
test_basic_bernoulli
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
...
...
@@ -261,7 +261,7 @@ class TestConvert(unittest.TestCase):
out4
=
x
.
bincount
(
weights
=
y
,
minlength
=
2
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randint
(
0
,
8
,
(
5
,),
dtype
=
torch
.
int64
),
torch
.
linspace
(
0
,
1
,
steps
=
5
),
))
def
test_basic_bitwise
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
...
...
@@ -279,4 +279,8 @@ class TestConvert(unittest.TestCase):
def
forward
(
self
,
x
):
out1
=
x
.
ceil
()
return
out1
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
\ No newline at end of file
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
    """Re-run every basic-operator TestConvert case with shape-aware conversion."""
test/ut/retiarii/test_convert_models.py
View file @
bb397480
...
...
@@ -10,11 +10,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
):
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
...
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
...
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model
=
Net
(
4
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
([
x
],
))
class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
    """Re-run every TestModels case using shape-aware conversion."""
test/ut/retiarii/test_convert_operators.py
View file @
bb397480
...
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
# following pytorch v1.7.1
class
TestOperators
(
unittest
.
TestCase
):
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
...
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
#print(model_code)
...
...
@@ -1042,7 +1042,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
tensor
([[[[
0.0
,
1.0
,
1.0
,
1.0
],
[
2.0
,
3.0
,
7.0
,
7.0
]]]],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_batchnorm
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1056,7 +1056,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
ones
(
2
,
2
,
2
,
2
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_batchnorm_1d
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1084,7 +1084,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
ones
(
20
,
16
,
50
,
40
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_conv_onnx_irv4_opset8
(
self
):
# This test point checks that for opset 8 (or lower), even if
# keep_initializers_as_inputs is set to False, it is ignored,
...
...
@@ -1129,7 +1129,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
16
,
50
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_maxpool_dilations
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1143,7 +1143,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
16
,
50
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_avg_pool2d
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1157,7 +1157,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
16
,
50
,
32
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
@
unittest
.
skip
(
'jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"'
)
def
test_basic_maxpool_indices
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1200,7 +1200,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_elu
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1214,7 +1214,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_selu
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1261,7 +1261,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
128
,
128
,
1
,
1
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_embedding_bags
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -1288,7 +1288,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_prelu
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1302,7 +1302,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_log_sigmoid
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1316,7 +1316,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_linear
(
self
):
class
SimpleOp
(
nn
.
Module
):
...
...
@@ -1385,4 +1385,7 @@ class TestOperators(unittest.TestCase):
return
out
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
\ No newline at end of file
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
    """Re-run every TestOperators case using shape-aware conversion."""
test/ut/retiarii/test_convert_pytorch.py
View file @
bb397480
...
...
@@ -15,11 +15,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
):
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
...
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
...
...
@@ -1230,4 +1230,7 @@ class TestPytorch(unittest.TestCase):
return
torch
.
arange
(
input
.
size
(
0
)),
torch
.
arange
(
input
.
size
(
-
1
)),
torch
.
ones
(
input
.
shape
)
x
=
torch
.
randn
(
5
,
3
,
2
)
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
\ No newline at end of file
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
    """Re-run every TestPytorch case using shape-aware conversion."""
test/ut/retiarii/test_convert_shape.py
0 → 100644
View file @
bb397480
import
unittest
import
torch
import
nni.retiarii.nn.pytorch
as
nn
from
.convert_mixin
import
ConvertWithShapeMixin
class TestShape(unittest.TestCase, ConvertWithShapeMixin):
    """Verify that shape-aware conversion records input/output shapes on nodes."""

    def test_simple_convnet(self):
        """Every node of a flat conv/relu/pool net carries recorded shapes."""
        class ConvNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 1, 3)
                self.relu = nn.ReLU()
                self.pool = nn.MaxPool2d(kernel_size=2)

            def forward(self, x):
                return self.pool(self.relu(self.conv(x)))

        model_ir = self._convert_model(ConvNet(), torch.randn((1, 3, 224, 224)))

        # (node type, expected input_shape, expected output_shape) for a
        # 224x224 input: 3x3 conv -> 222x222, relu keeps shape, 2x2 pool -> 111x111.
        expectations = [
            ('__torch__.torch.nn.modules.conv.Conv2d',
             [[1, 3, 224, 224]], [[1, 1, 222, 222]]),
            ('__torch__.torch.nn.modules.activation.ReLU',
             [[1, 1, 222, 222]], [[1, 1, 222, 222]]),
            ('__torch__.torch.nn.modules.pooling.MaxPool2d',
             [[1, 1, 222, 222]], [[1, 1, 111, 111]]),
        ]
        for node_type, expected_in, expected_out in expectations:
            node = model_ir.get_nodes_by_type(node_type)[0]
            self.assertEqual(node.operation.parameters.get('input_shape'), expected_in)
            self.assertEqual(node.operation.parameters.get('output_shape'), expected_out)

    def test_nested_module(self):
        """Shape info propagates onto cell nodes created from nested modules."""
        class ConvRelu(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 1, 3)
                self.relu = nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        class ConvNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = ConvRelu()
                self.pool = nn.MaxPool2d(kernel_size=2)

            def forward(self, x):
                return self.pool(self.conv(x))

        model_ir = self._convert_model(ConvNet(), torch.randn((1, 3, 224, 224)))

        # The nested ConvRelu submodule becomes a '_cell' node; shape
        # propagation should annotate it with the submodule's overall shapes.
        cell_params = model_ir.get_nodes_by_type('_cell')[0].operation.parameters
        self.assertEqual(cell_params.get('input_shape'), [[1, 3, 224, 224]])
        self.assertEqual(cell_params.get('output_shape'), [[1, 1, 222, 222]])

    def test_layerchoice(self):
        """Each LayerChoice candidate gets its own recorded output shape."""
        class ConvNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.LayerChoice([
                    nn.Conv2d(3, 1, 3),
                    nn.Conv2d(3, 1, 5, padding=1),
                ])
                self.pool = nn.MaxPool2d(kernel_size=2)

            def forward(self, x):
                return self.pool(self.conv(x))

        model_ir = self._convert_model(ConvNet(), torch.randn((1, 3, 224, 224)))

        # Both candidates (3x3 conv, and 5x5 conv with padding=1) map a
        # 224x224 input to 222x222, so their recorded output shapes agree.
        conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
        for candidate in conv_nodes[:2]:
            self.assertEqual(candidate.operation.parameters.get('output_shape'),
                             [[1, 1, 222, 222]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment