Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0f88b86b
"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "126782b4336e11f1725435c2b6e0710883cdc9b1"
Unverified
Commit
0f88b86b
authored
Jan 05, 2021
by
Yuge Zhang
Committed by
GitHub
Jan 05, 2021
Browse files
Retiarii graph and code generation test (#3231)
parent
4fae3ed9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
582 additions
and
3 deletions
+582
-3
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+4
-0
nni/retiarii/operation.py
nni/retiarii/operation.py
+5
-2
nni/retiarii/utils.py
nni/retiarii/utils.py
+6
-1
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+567
-0
No files found.
nni/retiarii/converter/op_types.py
View file @
0f88b86b
...
@@ -30,6 +30,10 @@ BasicOpsPT = {
...
@@ -30,6 +30,10 @@ BasicOpsPT = {
'aten::size'
:
'Size'
,
'aten::size'
:
'Size'
,
'aten::view'
:
'View'
,
'aten::view'
:
'View'
,
'aten::eq'
:
'Eq'
,
'aten::eq'
:
'Eq'
,
'aten::Bool'
:
'Bool'
,
'aten::empty'
:
'Empty'
,
'aten::zeros'
:
'Zeros'
,
'aten::chunk'
:
'Chunk'
,
'aten::add_'
:
'Add_'
# %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
'aten::add_'
:
'Add_'
# %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
}
...
...
nni/retiarii/operation.py
View file @
0f88b86b
...
@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
...
@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
return
f
'
{
output
}
=
{
value
}
'
return
f
'
{
output
}
=
{
value
}
'
elif
self
.
type
==
'prim::ListConstruct'
:
elif
self
.
type
==
'prim::ListConstruct'
:
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
elif
self
.
type
==
'prim::GetAttr'
:
return
f
"
{
output
}
=
{
self
.
parameters
[
'input'
]
}
.
{
self
.
parameters
[
'name'
]
}
"
elif
self
.
type
==
'aten::mean'
:
elif
self
.
type
==
'aten::mean'
:
return
f
'
{
output
}
= torch.mean(
{
inputs
[
0
]
}
,
{
", "
.
join
(
inputs
[
1
:
-
1
])
}
, out=
{
inputs
[
-
1
]
}
)'
return
f
'
{
output
}
= torch.mean(
{
inputs
[
0
]
}
,
{
", "
.
join
(
inputs
[
1
:
-
1
])
}
, out=
{
inputs
[
-
1
]
}
)'
elif
self
.
type
==
'aten::__getitem__'
:
elif
self
.
type
==
'aten::__getitem__'
:
...
@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
...
@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::add'
:
elif
self
.
type
==
'aten::add'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= '
+
' + '
.
join
(
inputs
)
return
f
'
{
output
}
=
{
inputs
[
0
]
}
+
{
inputs
[
1
]
}
'
elif
self
.
type
==
OpTypeName
.
MergedSlice
:
elif
self
.
type
==
OpTypeName
.
MergedSlice
:
assert
(
len
(
inputs
)
-
1
)
%
4
==
0
assert
(
len
(
inputs
)
-
1
)
%
4
==
0
slices
=
[]
slices
=
[]
...
@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
...
@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
1
]
}
)'
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::slice'
:
elif
self
.
type
==
'aten::slice'
:
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
elif
self
.
type
==
'aten::Bool'
:
return
f
'
{
output
}
= bool(
{
inputs
[
0
]
}
)'
else
:
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
...
...
nni/retiarii/utils.py
View file @
0f88b86b
...
@@ -27,6 +27,11 @@ def get_records():
...
@@ -27,6 +27,11 @@ def get_records():
return
_records
return
_records
def
clear_records
():
global
_records
_records
=
{}
def
add_record
(
key
,
value
):
def
add_record
(
key
,
value
):
"""
"""
"""
"""
...
@@ -56,7 +61,7 @@ def _blackbox_cls(cls, module_name, register_format=None):
...
@@ -56,7 +61,7 @@ def _blackbox_cls(cls, module_name, register_format=None):
# eject un-serializable arguments
# eject un-serializable arguments
for
k
in
list
(
full_args
.
keys
()):
for
k
in
list
(
full_args
.
keys
()):
# The list is not complete and does not support nested cases.
# The list is not complete and does not support nested cases.
if
not
isinstance
(
full_args
[
k
],
(
int
,
float
,
str
,
dict
,
list
)):
if
not
isinstance
(
full_args
[
k
],
(
int
,
float
,
str
,
dict
,
list
,
tuple
)):
if
not
(
register_format
==
'full'
and
k
==
'model'
):
if
not
(
register_format
==
'full'
and
k
==
'model'
):
# no warning if it is base model in trainer
# no warning if it is base model in trainer
warnings
.
warn
(
f
'
{
cls
}
has un-serializable arguments
{
k
}
whose value is
{
full_args
[
k
]
}
.
\
warnings
.
warn
(
f
'
{
cls
}
has un-serializable arguments
{
k
}
whose value is
{
full_args
[
k
]
}
.
\
...
...
test/ut/retiarii/test_convert.py
0 → 100644
View file @
0f88b86b
"""
Reference: We use tested models from https://github.com/pytorch/pytorch/blob/master/test/jit/test_models.py.
"""
import
os
import
sys
import
unittest
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MnistNet
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
10
,
kernel_size
=
5
)
self
.
conv2
=
nn
.
Conv2d
(
10
,
20
,
kernel_size
=
5
)
self
.
conv2_drop
=
nn
.
Dropout2d
()
self
.
fc1
=
nn
.
Linear
(
320
,
50
)
self
.
fc2
=
nn
.
Linear
(
50
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
F
.
max_pool2d
(
self
.
conv1
(
x
),
2
))
x
=
F
.
relu
(
F
.
max_pool2d
(
self
.
conv2_drop
(
self
.
conv2
(
x
)),
2
))
x
=
x
.
view
(
-
1
,
320
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
dropout
(
x
,
training
=
self
.
training
)
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
class
TestConvert
(
unittest
.
TestCase
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
cv
in
current_values
:
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
remove
(
cv
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
self
.
assertEqual
(
len
(
converted_output
),
len
(
expected_output
))
for
a
,
b
in
zip
(
converted_output
,
expected_output
):
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
return
converted_model
def
setUp
(
self
):
# FIXME
import
nni.retiarii.debug_configs
nni
.
retiarii
.
debug_configs
.
framework
=
'pytorch'
def
test_dcgan_models
(
self
):
class
DCGANGenerator
(
nn
.
Module
):
def
__init__
(
self
,
nz
,
ngf
,
nc
):
super
(
DCGANGenerator
,
self
).
__init__
()
self
.
main
=
nn
.
Sequential
(
# input is Z, going into a convolution
nn
.
ConvTranspose2d
(
nz
,
ngf
*
8
,
4
,
1
,
0
,
bias
=
False
),
nn
.
BatchNorm2d
(
ngf
*
8
),
nn
.
ReLU
(
True
),
# state size. (ngf*8) x 4 x 4
nn
.
ConvTranspose2d
(
ngf
*
8
,
ngf
*
4
,
4
,
2
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
ngf
*
4
),
nn
.
ReLU
(
True
),
# state size. (ngf*4) x 8 x 8
nn
.
ConvTranspose2d
(
ngf
*
4
,
ngf
*
2
,
4
,
2
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
ngf
*
2
),
nn
.
ReLU
(
True
),
# state size. (ngf*2) x 16 x 16
nn
.
ConvTranspose2d
(
ngf
*
2
,
ngf
,
4
,
2
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
ngf
),
nn
.
ReLU
(
True
),
# state size. (ngf) x 32 x 32
nn
.
ConvTranspose2d
(
ngf
,
nc
,
4
,
2
,
1
,
bias
=
False
),
nn
.
Tanh
()
# state size. (nc) x 64 x 64
)
def
forward
(
self
,
input
):
return
self
.
main
(
input
)
class
DCGANDiscriminator
(
nn
.
Module
):
def
__init__
(
self
,
nc
,
ndf
):
super
(
DCGANDiscriminator
,
self
).
__init__
()
self
.
main
=
nn
.
Sequential
(
# input is (nc) x 64 x 64
nn
.
Conv2d
(
nc
,
ndf
,
4
,
2
,
1
,
bias
=
False
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
# state size. (ndf) x 32 x 32
nn
.
Conv2d
(
ndf
,
ndf
*
2
,
4
,
2
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
ndf
*
2
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
# state size. (ndf*2) x 16 x 16
nn
.
Conv2d
(
ndf
*
2
,
ndf
*
4
,
4
,
2
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
ndf
*
4
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
# state size. (ndf*4) x 8 x 8
nn
.
Conv2d
(
ndf
*
4
,
ndf
*
8
,
4
,
2
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
ndf
*
8
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
# state size. (ndf*8) x 4 x 4
nn
.
Conv2d
(
ndf
*
8
,
1
,
4
,
1
,
0
,
bias
=
False
),
nn
.
Sigmoid
()
)
def
forward
(
self
,
input
):
return
self
.
main
(
input
).
view
(
-
1
,
1
).
squeeze
(
1
)
bs
,
nz
,
ngf
,
nc
,
ndf
=
5
,
6
,
9
,
3
,
10
input
=
(
torch
.
rand
(
bs
,
nz
,
1
,
1
),)
model
=
DCGANGenerator
(
nz
,
ngf
,
nc
)
self
.
checkExportImport
(
model
,
input
)
@
unittest
.
skip
(
'this test has a if condition that needs to be handle'
)
# FIXME
def
test_neural_style
(
self
):
class
TransformerNet
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
(
TransformerNet
,
self
).
__init__
()
# Initial convolution layers
self
.
conv1
=
ConvLayer
(
3
,
32
,
kernel_size
=
9
,
stride
=
1
)
self
.
in1
=
torch
.
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
conv2
=
ConvLayer
(
32
,
64
,
kernel_size
=
3
,
stride
=
2
)
self
.
in2
=
torch
.
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
conv3
=
ConvLayer
(
64
,
128
,
kernel_size
=
3
,
stride
=
2
)
self
.
in3
=
torch
.
nn
.
InstanceNorm2d
(
128
,
affine
=
True
)
# Residual layers
self
.
res1
=
ResidualBlock
(
128
)
self
.
res2
=
ResidualBlock
(
128
)
self
.
res3
=
ResidualBlock
(
128
)
self
.
res4
=
ResidualBlock
(
128
)
self
.
res5
=
ResidualBlock
(
128
)
# Upsampling Layers
self
.
deconv1
=
UpsampleConvLayer
(
128
,
64
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
in4
=
torch
.
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
deconv2
=
UpsampleConvLayer
(
64
,
32
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
in5
=
torch
.
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
deconv3
=
ConvLayer
(
32
,
3
,
kernel_size
=
9
,
stride
=
1
)
# Non-linearities
self
.
relu
=
torch
.
nn
.
ReLU
()
def
forward
(
self
,
X
):
y
=
self
.
relu
(
self
.
in1
(
self
.
conv1
(
X
)))
y
=
self
.
relu
(
self
.
in2
(
self
.
conv2
(
y
)))
y
=
self
.
relu
(
self
.
in3
(
self
.
conv3
(
y
)))
y
=
self
.
res1
(
y
)
y
=
self
.
res2
(
y
)
y
=
self
.
res3
(
y
)
y
=
self
.
res4
(
y
)
y
=
self
.
res5
(
y
)
y
=
self
.
relu
(
self
.
in4
(
self
.
deconv1
(
y
)))
y
=
self
.
relu
(
self
.
in5
(
self
.
deconv2
(
y
)))
y
=
self
.
deconv3
(
y
)
return
y
class
ConvLayer
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
):
super
(
ConvLayer
,
self
).
__init__
()
reflection_padding
=
kernel_size
//
2
self
.
reflection_pad
=
torch
.
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
def
forward
(
self
,
x
):
out
=
self
.
reflection_pad
(
x
)
out
=
self
.
conv2d
(
out
)
return
out
class
ResidualBlock
(
torch
.
nn
.
Module
):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
"""
def
__init__
(
self
,
channels
):
super
(
ResidualBlock
,
self
).
__init__
()
self
.
conv1
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
in1
=
torch
.
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
conv2
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
in2
=
torch
.
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
relu
=
torch
.
nn
.
ReLU
()
def
forward
(
self
,
x
):
residual
=
x
out
=
self
.
relu
(
self
.
in1
(
self
.
conv1
(
x
)))
out
=
self
.
in2
(
self
.
conv2
(
out
))
out
=
out
+
residual
return
out
class
UpsampleConvLayer
(
torch
.
nn
.
Module
):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
ref: http://distill.pub/2016/deconv-checkerboard/
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
upsample
=
None
):
super
(
UpsampleConvLayer
,
self
).
__init__
()
self
.
upsample
=
upsample
if
upsample
:
self
.
upsample_layer
=
torch
.
nn
.
Upsample
(
mode
=
'nearest'
,
scale_factor
=
upsample
)
reflection_padding
=
kernel_size
//
2
self
.
reflection_pad
=
torch
.
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
def
forward
(
self
,
x
):
x_in
=
x
if
self
.
upsample
:
x_in
=
self
.
upsample_layer
(
x_in
)
out
=
self
.
reflection_pad
(
x_in
)
out
=
self
.
conv2d
(
out
)
return
out
model
=
TransformerNet
()
input
=
(
torch
.
rand
(
5
,
3
,
16
,
16
),)
self
.
checkExportImport
(
model
,
input
)
def
test_mnist
(
self
):
# eval() is present because dropout makes this nondeterministic
self
.
checkExportImport
(
MnistNet
().
eval
(),
(
torch
.
rand
(
5
,
1
,
28
,
28
),))
def
test_reinforcement_learning
(
self
):
class
Policy
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Policy
,
self
).
__init__
()
self
.
affine1
=
nn
.
Linear
(
4
,
128
)
self
.
affine2
=
nn
.
Linear
(
128
,
2
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
affine1
(
x
))
action_scores
=
self
.
affine2
(
x
)
return
F
.
softmax
(
action_scores
,
dim
=
1
)
self
.
checkExportImport
(
Policy
(),
(
torch
.
rand
(
1
,
4
),))
@
unittest
.
skip
(
'Replaced init error.'
)
# FIXME
def
test_snli
(
self
):
class
Bottle
(
nn
.
Module
):
def
forward
(
self
,
input
):
if
len
(
input
.
size
())
<=
2
:
return
super
(
Bottle
,
self
).
forward
(
input
)
size
=
input
.
size
()[:
2
]
out
=
super
(
Bottle
,
self
).
forward
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
Linear
(
Bottle
,
nn
.
Linear
):
pass
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
Encoder
,
self
).
__init__
()
self
.
config
=
config
input_size
=
config
.
d_proj
if
config
.
projection
else
config
.
d_embed
dropout
=
0
if
config
.
n_layers
==
1
else
config
.
dp_ratio
self
.
rnn
=
nn
.
LSTM
(
input_size
=
input_size
,
hidden_size
=
config
.
d_hidden
,
num_layers
=
config
.
n_layers
,
dropout
=
dropout
,
bidirectional
=
config
.
birnn
)
def
forward
(
self
,
inputs
):
batch_size
=
inputs
.
size
()[
1
]
state_shape
=
self
.
config
.
n_cells
,
batch_size
,
self
.
config
.
d_hidden
h0
=
c0
=
inputs
.
new_zeros
(
state_shape
)
outputs
,
(
ht
,
ct
)
=
self
.
rnn
(
inputs
,
(
h0
,
c0
))
return
ht
[
-
1
]
if
not
self
.
config
.
birnn
else
ht
[
-
2
:].
transpose
(
0
,
1
).
contiguous
().
view
(
batch_size
,
-
1
)
class
SNLIClassifier
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
SNLIClassifier
,
self
).
__init__
()
self
.
config
=
config
self
.
embed
=
nn
.
Embedding
(
config
.
n_embed
,
config
.
d_embed
)
self
.
projection
=
Linear
(
config
.
d_embed
,
config
.
d_proj
)
self
.
encoder
=
Encoder
(
config
)
self
.
dropout
=
nn
.
Dropout
(
p
=
config
.
dp_ratio
)
self
.
relu
=
nn
.
ReLU
()
seq_in_size
=
2
*
config
.
d_hidden
if
self
.
config
.
birnn
:
seq_in_size
*=
2
lin_config
=
[
seq_in_size
]
*
2
self
.
out
=
nn
.
Sequential
(
Linear
(
*
lin_config
),
self
.
relu
,
self
.
dropout
,
Linear
(
*
lin_config
),
self
.
relu
,
self
.
dropout
,
Linear
(
*
lin_config
),
self
.
relu
,
self
.
dropout
,
Linear
(
seq_in_size
,
config
.
d_out
))
def
forward
(
self
,
premise
,
hypothesis
):
prem_embed
=
self
.
embed
(
premise
)
hypo_embed
=
self
.
embed
(
hypothesis
)
if
self
.
config
.
fix_emb
:
prem_embed
=
prem_embed
.
detach
()
hypo_embed
=
hypo_embed
.
detach
()
if
self
.
config
.
projection
:
prem_embed
=
self
.
relu
(
self
.
projection
(
prem_embed
))
hypo_embed
=
self
.
relu
(
self
.
projection
(
hypo_embed
))
premise
=
self
.
encoder
(
prem_embed
)
hypothesis
=
self
.
encoder
(
hypo_embed
)
scores
=
self
.
out
(
torch
.
cat
([
premise
,
hypothesis
],
1
))
return
scores
class
Config
:
n_embed
=
100
d_embed
=
100
d_proj
=
300
dp_ratio
=
0.0
# For deterministic testing TODO: change by fixing seed in checkTrace?
d_hidden
=
30
birnn
=
True
d_out
=
300
fix_emb
=
True
projection
=
True
n_layers
=
2
n_cells
=
4
# 2 * n_layers because birnn = True
premise
=
torch
.
LongTensor
(
48
,
64
).
random_
(
0
,
100
)
hypothesis
=
torch
.
LongTensor
(
24
,
64
).
random_
(
0
,
100
)
self
.
checkExportImport
(
SNLIClassifier
(
Config
()),
(
premise
,
hypothesis
))
def
test_super_resolution
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
upscale_factor
):
super
(
Net
,
self
).
__init__
()
self
.
relu
=
nn
.
ReLU
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
64
,
(
5
,
5
),
(
1
,
1
),
(
2
,
2
))
self
.
conv2
=
nn
.
Conv2d
(
64
,
64
,
(
3
,
3
),
(
1
,
1
),
(
1
,
1
))
self
.
conv3
=
nn
.
Conv2d
(
64
,
32
,
(
3
,
3
),
(
1
,
1
),
(
1
,
1
))
self
.
conv4
=
nn
.
Conv2d
(
32
,
upscale_factor
**
2
,
(
3
,
3
),
(
1
,
1
),
(
1
,
1
))
self
.
pixel_shuffle
=
nn
.
PixelShuffle
(
upscale_factor
)
def
forward
(
self
,
x
):
x
=
self
.
relu
(
self
.
conv1
(
x
))
x
=
self
.
relu
(
self
.
conv2
(
x
))
x
=
self
.
relu
(
self
.
conv3
(
x
))
x
=
self
.
pixel_shuffle
(
self
.
conv4
(
x
))
return
x
net
=
Net
(
upscale_factor
=
4
)
self
.
checkExportImport
(
net
,
(
torch
.
rand
(
5
,
1
,
32
,
32
),))
@
unittest
.
skip
(
'Need to support operator prim::ListUnpack'
)
# FIXME
def
test_time_sequence_prediction
(
self
):
class
Sequence
(
torch
.
jit
.
ScriptModule
):
def
__init__
(
self
):
super
(
Sequence
,
self
).
__init__
()
self
.
lstm1
=
nn
.
LSTMCell
(
1
,
51
)
self
.
lstm2
=
nn
.
LSTMCell
(
51
,
51
)
self
.
linear
=
nn
.
Linear
(
51
,
1
)
@
torch
.
jit
.
script_method
def
forward
(
self
,
input
):
# TODO: add future as input with default val
# see https://github.com/pytorch/pytorch/issues/8724
outputs
=
torch
.
empty
((
3
,
0
))
h_t
=
torch
.
zeros
((
3
,
51
))
c_t
=
torch
.
zeros
((
3
,
51
))
h_t2
=
torch
.
zeros
((
3
,
51
))
c_t2
=
torch
.
zeros
((
3
,
51
))
output
=
torch
.
zeros
([
3
,
51
])
future
=
2
# TODO: chunk call should appear as the for loop iterable
# We hard-code it to 4 for now.
a
,
b
,
c
,
d
=
input
.
chunk
(
input
.
size
(
1
),
dim
=
1
)
for
input_t
in
(
a
,
b
,
c
,
d
):
h_t
,
c_t
=
self
.
lstm1
(
input_t
,
(
h_t
,
c_t
))
h_t2
,
c_t2
=
self
.
lstm2
(
h_t
,
(
h_t2
,
c_t2
))
output
=
self
.
linear
(
h_t2
)
outputs
=
torch
.
cat
((
outputs
,
output
),
1
)
for
_
in
range
(
future
):
# if we should predict the future
h_t
,
c_t
=
self
.
lstm1
(
output
,
(
h_t
,
c_t
))
h_t2
,
c_t2
=
self
.
lstm2
(
h_t
,
(
h_t2
,
c_t2
))
output
=
self
.
linear
(
h_t2
)
outputs
=
torch
.
cat
((
outputs
,
output
),
1
)
return
outputs
class
Traced
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Traced
,
self
).
__init__
()
self
.
seq
=
Sequence
()
def
forward
(
self
,
input
):
return
self
.
seq
.
forward
(
input
)
self
.
checkExportImport
(
Traced
(),
(
torch
.
rand
(
3
,
4
),))
@
unittest
.
skip
(
'Unsupported callmethod encode'
)
# FIXME
def
test_vae
(
self
):
class
VAE
(
nn
.
Module
):
def
__init__
(
self
):
super
(
VAE
,
self
).
__init__
()
self
.
fc1
=
nn
.
Linear
(
784
,
400
)
self
.
fc21
=
nn
.
Linear
(
400
,
20
)
self
.
fc22
=
nn
.
Linear
(
400
,
20
)
self
.
fc3
=
nn
.
Linear
(
20
,
400
)
self
.
fc4
=
nn
.
Linear
(
400
,
784
)
def
encode
(
self
,
x
):
h1
=
F
.
relu
(
self
.
fc1
(
x
))
return
self
.
fc21
(
h1
),
self
.
fc22
(
h1
)
def
reparameterize
(
self
,
mu
,
logvar
):
if
self
.
training
:
std
=
torch
.
exp
(
0.5
*
logvar
)
eps
=
torch
.
randn_like
(
std
)
return
eps
.
mul
(
std
).
add_
(
mu
)
else
:
return
mu
def
decode
(
self
,
z
):
h3
=
F
.
relu
(
self
.
fc3
(
z
))
return
torch
.
sigmoid
(
self
.
fc4
(
h3
))
def
forward
(
self
,
x
):
mu
,
logvar
=
self
.
encode
(
x
.
view
(
-
1
,
784
))
z
=
self
.
reparameterize
(
mu
,
logvar
)
return
self
.
decode
(
z
),
mu
,
logvar
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
@
unittest
.
skip
(
'torchvision models are not supported yet'
)
# FIXME
def
test_torchvision_resnet18
(
self
):
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
@
unittest
.
skip
(
'Unsupported CallMethod _forward_impl'
)
# FIXME
def
test_resnet
(
self
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
"""1x1 convolution"""
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
)
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
):
"""3x3 convolution with padding"""
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
False
)
class
BasicBlock
(
torch
.
jit
.
ScriptModule
):
expansion
=
1
__constants__
=
[
'downsample'
]
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
conv2
=
conv3x3
(
planes
,
planes
)
self
.
bn2
=
nn
.
BatchNorm2d
(
planes
)
self
.
downsample
=
downsample
self
.
stride
=
stride
@
torch
.
jit
.
script_method
def
forward
(
self
,
x
):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
out
+=
residual
out
=
self
.
relu
(
out
)
return
out
class
ResNet
(
torch
.
jit
.
ScriptModule
):
__constants__
=
[
'layer1'
,
'layer2'
,
'layer3'
,
'layer4'
]
def
__init__
(
self
,
block
,
layers
,
num_classes
=
1000
):
super
(
ResNet
,
self
).
__init__
()
self
.
inplanes
=
64
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
layer1
=
self
.
_make_layer
(
block
,
64
,
layers
[
0
])
self
.
layer2
=
self
.
_make_layer
(
block
,
128
,
layers
[
1
],
stride
=
2
)
self
.
layer3
=
self
.
_make_layer
(
block
,
256
,
layers
[
2
],
stride
=
2
)
self
.
layer4
=
self
.
_make_layer
(
block
,
512
,
layers
[
3
],
stride
=
2
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool2d
((
1
,
1
))
self
.
fc
=
nn
.
Linear
(
512
*
block
.
expansion
,
num_classes
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
torch
.
nn
.
init
.
kaiming_normal_
(
m
.
weight
,
mode
=
'fan_out'
,
nonlinearity
=
'relu'
)
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
torch
.
nn
.
init
.
constant_
(
m
.
weight
,
1
)
torch
.
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
):
downsample
=
None
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
conv1x1
(
self
.
inplanes
,
planes
*
block
.
expansion
,
stride
),
nn
.
BatchNorm2d
(
planes
*
block
.
expansion
),
)
layers
=
[]
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
stride
,
downsample
))
self
.
inplanes
=
planes
*
block
.
expansion
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
inplanes
,
planes
))
return
nn
.
Sequential
(
*
layers
)
@
torch
.
jit
.
script_method
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
x
=
self
.
layer1
(
x
)
x
=
self
.
layer2
(
x
)
x
=
self
.
layer3
(
x
)
x
=
self
.
layer4
(
x
)
x
=
self
.
avgpool
(
x
)
x
=
x
.
view
(
x
.
size
(
0
),
-
1
)
x
=
self
.
fc
(
x
)
return
x
resnet18
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
])
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
@
unittest
.
skip
(
'torchvision models are not supported yet'
)
# FIXME
def
test_alexnet
(
self
):
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment