Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bb397480
"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "ee390c0b552e2af13a52672a667ff75a934aab6d"
Unverified
Commit
bb397480
authored
Jul 20, 2021
by
kalineid
Committed by
GitHub
Jul 20, 2021
Browse files
Add tests for GraphConverterWithShape (#3951)
parent
403195f0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
154 additions
and
40 deletions
+154
-40
test/ut/retiarii/convert_mixin.py
test/ut/retiarii/convert_mixin.py
+19
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+7
-4
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+14
-10
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+7
-4
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+20
-17
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+8
-5
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+79
-0
No files found.
test/ut/retiarii/convert_mixin.py
0 → 100644
View file @
bb397480
import
torch
from
nni.retiarii.converter.graph_gen
import
convert_to_graph
,
GraphConverterWithShape
class
ConvertMixin
:
@
staticmethod
def
_convert_model
(
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
return
model_ir
class
ConvertWithShapeMixin
:
@
staticmethod
def
_convert_model
(
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
,
converter
=
GraphConverterWithShape
(),
example_inputs
=
input
)
return
model_ir
test/ut/retiarii/test_convert.py
View file @
bb397480
...
@@ -13,9 +13,10 @@ import torchvision
...
@@ -13,9 +13,10 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
MnistNet
(
nn
.
Module
):
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
MnistNet
,
self
).
__init__
()
super
(
MnistNet
,
self
).
__init__
()
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
...
@@ -48,7 +49,7 @@ class Linear(nn.Module):
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
...
@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
):
def
checkExportImport
(
self
,
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
exec_vars
=
{}
exec_vars
=
{}
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
...
@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
model
,
(
x
,))
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
finally
:
remove_inject_pytorch_nn
()
remove_inject_pytorch_nn
()
class
TestConvertWithShape
(
TestConvert
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_basic.py
View file @
bb397480
...
@@ -9,12 +9,13 @@ import torchvision
...
@@ -9,12 +9,13 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
...
@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -188,7 +188,7 @@ class TestConvert(unittest.TestCase):
...
@@ -188,7 +188,7 @@ class TestConvert(unittest.TestCase):
out2
=
torch
.
addmv
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
out2
=
torch
.
addmv
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
return
out1
,
out2
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
),
))
def
test_basic_addr
(
self
):
def
test_basic_addr
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
def
forward
(
self
,
x
,
y
,
z
):
...
@@ -204,7 +204,7 @@ class TestConvert(unittest.TestCase):
...
@@ -204,7 +204,7 @@ class TestConvert(unittest.TestCase):
out2
=
torch
.
allclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
out2
=
torch
.
allclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
return
out1
,
out2
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
10000.
,
1e-07
]),
torch
.
tensor
([
10000.1
,
1e-08
]),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
10000.
,
1e-07
]),
torch
.
tensor
([
10000.1
,
1e-08
]),
))
def
test_basic_angle
(
self
):
def
test_basic_angle
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -229,7 +229,7 @@ class TestConvert(unittest.TestCase):
...
@@ -229,7 +229,7 @@ class TestConvert(unittest.TestCase):
o4
=
x
.
argmin
(
dim
=
1
,
keepdim
=
True
)
o4
=
x
.
argmin
(
dim
=
1
,
keepdim
=
True
)
return
out1
,
out2
,
out3
,
out4
,
out5
,
o1
,
o2
,
o3
,
o4
return
out1
,
out2
,
out3
,
out4
,
out5
,
o1
,
o2
,
o3
,
o4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
def
test_basic_argsort
(
self
):
def
test_basic_argsort
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -241,7 +241,7 @@ class TestConvert(unittest.TestCase):
...
@@ -241,7 +241,7 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
# skip backward(gradient=None, retain_graph=None, create_graph=False)
# skip backward(gradient=None, retain_graph=None, create_graph=False)
def
test_basic_bernoulli
(
self
):
def
test_basic_bernoulli
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -261,7 +261,7 @@ class TestConvert(unittest.TestCase):
...
@@ -261,7 +261,7 @@ class TestConvert(unittest.TestCase):
out4
=
x
.
bincount
(
weights
=
y
,
minlength
=
2
)
out4
=
x
.
bincount
(
weights
=
y
,
minlength
=
2
)
return
out1
,
out2
,
out3
,
out4
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randint
(
0
,
8
,
(
5
,),
dtype
=
torch
.
int64
),
torch
.
linspace
(
0
,
1
,
steps
=
5
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randint
(
0
,
8
,
(
5
,),
dtype
=
torch
.
int64
),
torch
.
linspace
(
0
,
1
,
steps
=
5
),
))
def
test_basic_bitwise
(
self
):
def
test_basic_bitwise
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
def
forward
(
self
,
x
,
y
):
...
@@ -279,4 +279,8 @@ class TestConvert(unittest.TestCase):
...
@@ -279,4 +279,8 @@ class TestConvert(unittest.TestCase):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
out1
=
x
.
ceil
()
out1
=
x
.
ceil
()
return
out1
return
out1
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
\ No newline at end of file
class
TestConvertWithShape
(
TestConvert
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_models.py
View file @
bb397480
...
@@ -10,11 +10,12 @@ import torchvision
...
@@ -10,11 +10,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
):
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
...
@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return
result
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
...
@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model
=
Net
(
4
)
model
=
Net
(
4
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
([
x
],
))
self
.
run_test
(
model
,
([
x
],
))
class
TestModelsWithShape
(
TestModels
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_operators.py
View file @
bb397480
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
...
@@ -15,13 +15,14 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestOperators
(
unittest
.
TestCase
):
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
...
@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return
result
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
#print(model_code)
#print(model_code)
...
@@ -1042,7 +1042,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1042,7 +1042,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
tensor
([[[[
0.0
,
1.0
,
1.0
,
1.0
],
[
2.0
,
3.0
,
7.0
,
7.0
]]]],
requires_grad
=
True
)
x
=
torch
.
tensor
([[[[
0.0
,
1.0
,
1.0
,
1.0
],
[
2.0
,
3.0
,
7.0
,
7.0
]]]],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_batchnorm
(
self
):
def
test_basic_batchnorm
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1056,7 +1056,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1056,7 +1056,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
ones
(
2
,
2
,
2
,
2
,
requires_grad
=
True
)
x
=
torch
.
ones
(
2
,
2
,
2
,
2
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_batchnorm_1d
(
self
):
def
test_basic_batchnorm_1d
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1084,7 +1084,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1084,7 +1084,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
ones
(
20
,
16
,
50
,
40
,
requires_grad
=
True
)
x
=
torch
.
ones
(
20
,
16
,
50
,
40
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_conv_onnx_irv4_opset8
(
self
):
def
test_conv_onnx_irv4_opset8
(
self
):
# This test point checks that for opset 8 (or lower), even if
# This test point checks that for opset 8 (or lower), even if
# keep_initializers_as_inputs is set to False, it is ignored,
# keep_initializers_as_inputs is set to False, it is ignored,
...
@@ -1129,7 +1129,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1129,7 +1129,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
16
,
50
)
x
=
torch
.
randn
(
20
,
16
,
50
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_maxpool_dilations
(
self
):
def
test_basic_maxpool_dilations
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1143,7 +1143,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1143,7 +1143,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
16
,
50
)
x
=
torch
.
randn
(
20
,
16
,
50
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_avg_pool2d
(
self
):
def
test_basic_avg_pool2d
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1157,7 +1157,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1157,7 +1157,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
20
,
16
,
50
,
32
)
x
=
torch
.
randn
(
20
,
16
,
50
,
32
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
@
unittest
.
skip
(
'jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"'
)
@
unittest
.
skip
(
'jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"'
)
def
test_basic_maxpool_indices
(
self
):
def
test_basic_maxpool_indices
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1200,7 +1200,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1200,7 +1200,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
,
requires_grad
=
True
)
x
=
torch
.
randn
(
1
,
2
,
3
,
4
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_elu
(
self
):
def
test_basic_elu
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1214,7 +1214,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1214,7 +1214,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
,
requires_grad
=
True
)
x
=
torch
.
randn
(
1
,
2
,
3
,
4
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_selu
(
self
):
def
test_basic_selu
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1261,7 +1261,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1261,7 +1261,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
128
,
128
,
1
,
1
,
requires_grad
=
True
)
x
=
torch
.
randn
(
128
,
128
,
1
,
1
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_embedding_bags
(
self
):
def
test_embedding_bags
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -1288,7 +1288,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1288,7 +1288,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_prelu
(
self
):
def
test_basic_prelu
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1302,7 +1302,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1302,7 +1302,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_log_sigmoid
(
self
):
def
test_basic_log_sigmoid
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1316,7 +1316,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1316,7 +1316,7 @@ class TestOperators(unittest.TestCase):
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
x
=
torch
.
randn
(
1
,
2
,
3
,
4
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_linear
(
self
):
def
test_basic_linear
(
self
):
class
SimpleOp
(
nn
.
Module
):
class
SimpleOp
(
nn
.
Module
):
...
@@ -1385,4 +1385,7 @@ class TestOperators(unittest.TestCase):
...
@@ -1385,4 +1385,7 @@ class TestOperators(unittest.TestCase):
return
out
return
out
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
x
=
torch
.
randn
(
20
,
5
,
10
,
10
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
\ No newline at end of file
class
TestOperatorsWithShape
(
TestOperators
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_pytorch.py
View file @
bb397480
...
@@ -15,11 +15,12 @@ import torchvision
...
@@ -15,11 +15,12 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
):
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
...
@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return
result
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
print
(
model_code
)
...
@@ -1230,4 +1230,7 @@ class TestPytorch(unittest.TestCase):
...
@@ -1230,4 +1230,7 @@ class TestPytorch(unittest.TestCase):
return
torch
.
arange
(
input
.
size
(
0
)),
torch
.
arange
(
input
.
size
(
-
1
)),
torch
.
ones
(
input
.
shape
)
return
torch
.
arange
(
input
.
size
(
0
)),
torch
.
arange
(
input
.
size
(
-
1
)),
torch
.
ones
(
input
.
shape
)
x
=
torch
.
randn
(
5
,
3
,
2
)
x
=
torch
.
randn
(
5
,
3
,
2
)
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
\ No newline at end of file
class
TestPytorchWithShape
(
TestPytorch
,
ConvertWithShapeMixin
):
pass
test/ut/retiarii/test_convert_shape.py
0 → 100644
View file @
bb397480
import
unittest
import
torch
import
nni.retiarii.nn.pytorch
as
nn
from
.convert_mixin
import
ConvertWithShapeMixin
class
TestShape
(
unittest
.
TestCase
,
ConvertWithShapeMixin
):
def
test_simple_convnet
(
self
):
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
1
,
3
)
self
.
relu
=
nn
.
ReLU
()
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
relu
(
self
.
conv
(
x
)))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
self
.
assertEqual
(
conv_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
def
test_nested_module
(
self
):
class
ConvRelu
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
1
,
3
)
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
return
self
.
relu
(
self
.
conv
(
x
))
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
ConvRelu
()
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
conv
(
x
))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
# check if shape propagation works
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
self
.
assertEqual
(
cell_node
.
operation
.
parameters
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
def
test_layerchoice
(
self
):
class
ConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
1
,
3
),
nn
.
Conv2d
(
3
,
1
,
5
,
padding
=
1
),
])
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
)
def
forward
(
self
,
x
):
return
self
.
pool
(
self
.
conv
(
x
))
net
=
ConvNet
()
input
=
torch
.
randn
((
1
,
3
,
224
,
224
))
model_ir
=
self
.
_convert_model
(
net
,
input
)
# check shape info of each candidates
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
parameters
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment