OpenDAS / nni — commit 0f88b86b (unverified)

Retiarii graph and code generation test (#3231)

Authored Jan 05, 2021 by Yuge Zhang; committed by GitHub on Jan 05, 2021
Parent: 4fae3ed9
Showing 4 changed files, with 582 additions and 3 deletions (+582 −3)
nni/retiarii/converter/op_types.py   +4   −0
nni/retiarii/operation.py            +5   −2
nni/retiarii/utils.py                +6   −1
test/ut/retiarii/test_convert.py     +567 −0
nni/retiarii/converter/op_types.py

@@ -30,6 +30,10 @@ BasicOpsPT = {
     'aten::size': 'Size',
     'aten::view': 'View',
     'aten::eq': 'Eq',
+    'aten::Bool': 'Bool',
+    'aten::empty': 'Empty',
+    'aten::zeros': 'Zeros',
+    'aten::chunk': 'Chunk',
     'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
 }
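The four new entries register TorchScript node kinds that the converter can now map to IR op types; aten::chunk, for instance, is exercised by the new time-sequence test below. A rough sketch of the lookup this table enables during conversion (the helper ir_type_for is illustrative, not part of this commit):

# Illustrative only: BasicOpsPT maps a TorchScript node kind string
# (as returned by node.kind()) to a Retiarii IR op-type name.
BasicOpsPT = {
    'aten::Bool': 'Bool',
    'aten::empty': 'Empty',
    'aten::zeros': 'Zeros',
    'aten::chunk': 'Chunk',
}

def ir_type_for(kind):
    # Hypothetical fallback: keep the raw kind when no mapping exists.
    return BasicOpsPT.get(kind, kind)

assert ir_type_for('aten::chunk') == 'Chunk'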
nni/retiarii/operation.py

@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
             return f'{output} = {value}'
         elif self.type == 'prim::ListConstruct':
             return f'{output} = [{", ".join(inputs)}]'
+        elif self.type == 'prim::GetAttr':
+            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
         elif self.type == 'aten::mean':
             return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
         elif self.type == 'aten::__getitem__':
@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
             assert len(inputs) == 2
             return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
         elif self.type == 'aten::add':
-            assert len(inputs) == 2
-            return f'{output} = ' + ' + '.join(inputs)
+            return f'{output} = {inputs[0]} + {inputs[1]}'
         elif self.type == OpTypeName.MergedSlice:
             assert (len(inputs) - 1) % 4 == 0
             slices = []
@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
             return f'{output} = {inputs[0]}.view({inputs[1]})'
         elif self.type == 'aten::slice':
             raise RuntimeError('not supposed to have aten::slice operation')
+        elif self.type == 'aten::Bool':
+            return f'{output} = bool({inputs[0]})'
         else:
             raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
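The prim::GetAttr and aten::Bool branches are new, and the aten::add branch is rewritten so it no longer assumes (or asserts) exactly two string inputs joinable with ' + '. A hedged sketch of the strings these f-string templates emit, with invented stand-in values (in real runs the values come from the graph IR):

# Stand-in values; in practice they are supplied by the converter.
output = 'out1'
inputs = ['x1', 'x2']
parameters = {'input': 'self', 'name': 'conv1'}

print(f"{output} = {parameters['input']}.{parameters['name']}")  # prim::GetAttr -> out1 = self.conv1
print(f'{output} = {inputs[0]} + {inputs[1]}')                   # aten::add     -> out1 = x1 + x2
print(f'{output} = bool({inputs[0]})')                           # aten::Bool    -> out1 = bool(x1)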
nni/retiarii/utils.py

@@ -27,6 +27,11 @@ def get_records():
     return _records
 
 
+def clear_records():
+    global _records
+    _records = {}
+
+
 def add_record(key, value):
     """
     """
@@ -56,7 +61,7 @@ def _blackbox_cls(cls, module_name, register_format=None):
         # eject un-serializable arguments
         for k in list(full_args.keys()):
             # The list is not complete and does not support nested cases.
-            if not isinstance(full_args[k], (int, float, str, dict, list)):
+            if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
                 if not (register_format == 'full' and k == 'model'):
                     # no warning if it is base model in trainer
                     warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
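clear_records gives callers a way to reset the module-level _records dict between runs. A minimal usage sketch, assuming add_record stores key → value into _records (this hunk does not show its body, so treat the semantics as an assumption):

from nni.retiarii.utils import add_record, get_records, clear_records

add_record('conv1', {'in_channels': 1})  # assumed: _records['conv1'] = {...}
assert 'conv1' in get_records()

clear_records()                          # reset global state, e.g. between tests
assert get_records() == {}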
test/ut/retiarii/test_convert.py  (new file, 0 → 100644)
"""
Reference: We use tested models from https://github.com/pytorch/pytorch/blob/master/test/jit/test_models.py.
"""
import
os
import
sys
import
unittest
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MnistNet
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
10
,
kernel_size
=
5
)
self
.
conv2
=
nn
.
Conv2d
(
10
,
20
,
kernel_size
=
5
)
self
.
conv2_drop
=
nn
.
Dropout2d
()
self
.
fc1
=
nn
.
Linear
(
320
,
50
)
self
.
fc2
=
nn
.
Linear
(
50
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
F
.
max_pool2d
(
self
.
conv1
(
x
),
2
))
x
=
F
.
relu
(
F
.
max_pool2d
(
self
.
conv2_drop
(
self
.
conv2
(
x
)),
2
))
x
=
x
.
view
(
-
1
,
320
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
dropout
(
x
,
training
=
self
.
training
)
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)

class TestConvert(unittest.TestCase):
    @staticmethod
    def _match_state_dict(current_values, expected_format):
        result = {}
        for k, v in expected_format.items():
            for cv in current_values:
                if cv.shape == v.shape:
                    result[k] = cv
                    current_values.remove(cv)
                    break
        return result

    def checkExportImport(self, model, input):
        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(script_module, model)
        model_code = model_to_pytorch_script(model_ir)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
                                                      dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)
        self.assertEqual(len(converted_output), len(expected_output))
        for a, b in zip(converted_output, expected_output):
            self.assertLess((a - b).abs().max().item(), 1E-4)
        return converted_model

    def setUp(self):
        # FIXME
        import nni.retiarii.debug_configs
        nni.retiarii.debug_configs.framework = 'pytorch'

    def test_dcgan_models(self):
        class DCGANGenerator(nn.Module):
            def __init__(self, nz, ngf, nc):
                super(DCGANGenerator, self).__init__()
                self.main = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                    nn.BatchNorm2d(ngf * 8),
                    nn.ReLU(True),
                    # state size. (ngf*8) x 4 x 4
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 4),
                    nn.ReLU(True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 2),
                    nn.ReLU(True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf),
                    nn.ReLU(True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                    nn.Tanh()
                    # state size. (nc) x 64 x 64
                )

            def forward(self, input):
                return self.main(input)

        class DCGANDiscriminator(nn.Module):
            def __init__(self, nc, ndf):
                super(DCGANDiscriminator, self).__init__()
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 4),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 8),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                    nn.Sigmoid()
                )

            def forward(self, input):
                return self.main(input).view(-1, 1).squeeze(1)

        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
        input = (torch.rand(bs, nz, 1, 1),)
        model = DCGANGenerator(nz, ngf, nc)
        self.checkExportImport(model, input)

    @unittest.skip('this test has a if condition that needs to be handle')  # FIXME
    def test_neural_style(self):
        class TransformerNet(torch.nn.Module):
            def __init__(self):
                super(TransformerNet, self).__init__()
                # Initial convolution layers
                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
                # Residual layers
                self.res1 = ResidualBlock(128)
                self.res2 = ResidualBlock(128)
                self.res3 = ResidualBlock(128)
                self.res4 = ResidualBlock(128)
                self.res5 = ResidualBlock(128)
                # Upsampling Layers
                self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
                self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                # Non-linearities
                self.relu = torch.nn.ReLU()

            def forward(self, X):
                y = self.relu(self.in1(self.conv1(X)))
                y = self.relu(self.in2(self.conv2(y)))
                y = self.relu(self.in3(self.conv3(y)))
                y = self.res1(y)
                y = self.res2(y)
                y = self.res3(y)
                y = self.res4(y)
                y = self.res5(y)
                y = self.relu(self.in4(self.deconv1(y)))
                y = self.relu(self.in5(self.deconv2(y)))
                y = self.deconv3(y)
                return y

        class ConvLayer(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, stride):
                super(ConvLayer, self).__init__()
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

            def forward(self, x):
                out = self.reflection_pad(x)
                out = self.conv2d(out)
                return out

        class ResidualBlock(torch.nn.Module):
            """ResidualBlock
            introduced in: https://arxiv.org/abs/1512.03385
            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
            """

            def __init__(self, channels):
                super(ResidualBlock, self).__init__()
                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                residual = x
                out = self.relu(self.in1(self.conv1(x)))
                out = self.in2(self.conv2(out))
                out = out + residual
                return out

        class UpsampleConvLayer(torch.nn.Module):
            """UpsampleConvLayer
            Upsamples the input and then does a convolution. This method gives better results
            compared to ConvTranspose2d.
            ref: http://distill.pub/2016/deconv-checkerboard/
            """

            def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
                super(UpsampleConvLayer, self).__init__()
                self.upsample = upsample
                if upsample:
                    self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

            def forward(self, x):
                x_in = x
                if self.upsample:
                    x_in = self.upsample_layer(x_in)
                out = self.reflection_pad(x_in)
                out = self.conv2d(out)
                return out

        model = TransformerNet()
        input = (torch.rand(5, 3, 16, 16),)
        self.checkExportImport(model, input)

    def test_mnist(self):
        # eval() is present because dropout makes this nondeterministic
        self.checkExportImport(MnistNet().eval(), (torch.rand(5, 1, 28, 28),))

    def test_reinforcement_learning(self):
        class Policy(nn.Module):
            def __init__(self):
                super(Policy, self).__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        self.checkExportImport(Policy(), (torch.rand(1, 4),))

    @unittest.skip('Replaced init error.')  # FIXME
    def test_snli(self):
        class Bottle(nn.Module):
            def forward(self, input):
                if len(input.size()) <= 2:
                    return super(Bottle, self).forward(input)
                size = input.size()[:2]
                out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
                return out.view(size[0], size[1], -1)

        class Linear(Bottle, nn.Linear):
            pass

        class Encoder(nn.Module):
            def __init__(self, config):
                super(Encoder, self).__init__()
                self.config = config
                input_size = config.d_proj if config.projection else config.d_embed
                dropout = 0 if config.n_layers == 1 else config.dp_ratio
                self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
                                   num_layers=config.n_layers, dropout=dropout,
                                   bidirectional=config.birnn)

            def forward(self, inputs):
                batch_size = inputs.size()[1]
                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
                h0 = c0 = inputs.new_zeros(state_shape)
                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
                return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)

        class SNLIClassifier(nn.Module):
            def __init__(self, config):
                super(SNLIClassifier, self).__init__()
                self.config = config
                self.embed = nn.Embedding(config.n_embed, config.d_embed)
                self.projection = Linear(config.d_embed, config.d_proj)
                self.encoder = Encoder(config)
                self.dropout = nn.Dropout(p=config.dp_ratio)
                self.relu = nn.ReLU()
                seq_in_size = 2 * config.d_hidden
                if self.config.birnn:
                    seq_in_size *= 2
                lin_config = [seq_in_size] * 2
                self.out = nn.Sequential(
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(seq_in_size, config.d_out))

            def forward(self, premise, hypothesis):
                prem_embed = self.embed(premise)
                hypo_embed = self.embed(hypothesis)
                if self.config.fix_emb:
                    prem_embed = prem_embed.detach()
                    hypo_embed = hypo_embed.detach()
                if self.config.projection:
                    prem_embed = self.relu(self.projection(prem_embed))
                    hypo_embed = self.relu(self.projection(hypo_embed))
                premise = self.encoder(prem_embed)
                hypothesis = self.encoder(hypo_embed)
                scores = self.out(torch.cat([premise, hypothesis], 1))
                return scores

        class Config:
            n_embed = 100
            d_embed = 100
            d_proj = 300
            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
            d_hidden = 30
            birnn = True
            d_out = 300
            fix_emb = True
            projection = True
            n_layers = 2
            n_cells = 4  # 2 * n_layers because birnn = True

        premise = torch.LongTensor(48, 64).random_(0, 100)
        hypothesis = torch.LongTensor(24, 64).random_(0, 100)
        self.checkExportImport(SNLIClassifier(Config()), (premise, hypothesis))

    def test_super_resolution(self):
        class Net(nn.Module):
            def __init__(self, upscale_factor):
                super(Net, self).__init__()
                self.relu = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

        net = Net(upscale_factor=4)
        self.checkExportImport(net, (torch.rand(5, 1, 32, 32),))

    @unittest.skip('Need to support operator prim::ListUnpack')  # FIXME
    def test_time_sequence_prediction(self):
        class Sequence(torch.jit.ScriptModule):
            def __init__(self):
                super(Sequence, self).__init__()
                self.lstm1 = nn.LSTMCell(1, 51)
                self.lstm2 = nn.LSTMCell(51, 51)
                self.linear = nn.Linear(51, 1)

            @torch.jit.script_method
            def forward(self, input):
                # TODO: add future as input with default val
                # see https://github.com/pytorch/pytorch/issues/8724
                outputs = torch.empty((3, 0))
                h_t = torch.zeros((3, 51))
                c_t = torch.zeros((3, 51))
                h_t2 = torch.zeros((3, 51))
                c_t2 = torch.zeros((3, 51))
                output = torch.zeros([3, 51])
                future = 2

                # TODO: chunk call should appear as the for loop iterable
                # We hard-code it to 4 for now.
                a, b, c, d = input.chunk(input.size(1), dim=1)
                for input_t in (a, b, c, d):
                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                for _ in range(future):  # if we should predict the future
                    h_t, c_t = self.lstm1(output, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                return outputs

        class Traced(nn.Module):
            def __init__(self):
                super(Traced, self).__init__()
                self.seq = Sequence()

            def forward(self, input):
                return self.seq.forward(input)

        self.checkExportImport(Traced(), (torch.rand(3, 4),))

    @unittest.skip('Unsupported callmethod encode')  # FIXME
    def test_vae(self):
        class VAE(nn.Module):
            def __init__(self):
                super(VAE, self).__init__()
                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),))

    @unittest.skip('torchvision models are not supported yet')  # FIXME
    def test_torchvision_resnet18(self):
        self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))

    @unittest.skip('Unsupported CallMethod _forward_impl')  # FIXME
    def test_resnet(self):
        def conv1x1(in_planes, out_planes, stride=1):
            """1x1 convolution"""
            return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

        def conv3x3(in_planes, out_planes, stride=1):
            """3x3 convolution with padding"""
            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

        class BasicBlock(torch.jit.ScriptModule):
            expansion = 1
            __constants__ = ['downsample']

            def __init__(self, inplanes, planes, stride=1, downsample=None):
                super(BasicBlock, self).__init__()
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = nn.BatchNorm2d(planes)
                self.relu = nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = nn.BatchNorm2d(planes)
                self.downsample = downsample
                self.stride = stride

            @torch.jit.script_method
            def forward(self, x):
                residual = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    residual = self.downsample(x)

                out += residual
                out = self.relu(out)

                return out

        class ResNet(torch.jit.ScriptModule):
            __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']

            def __init__(self, block, layers, num_classes=1000):
                super(ResNet, self).__init__()
                self.inplanes = 64
                self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.bn1 = nn.BatchNorm2d(64)
                self.relu = nn.ReLU(inplace=True)
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc = nn.Linear(512 * block.expansion, num_classes)

                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    elif isinstance(m, nn.BatchNorm2d):
                        torch.nn.init.constant_(m.weight, 1)
                        torch.nn.init.constant_(m.bias, 0)

            def _make_layer(self, block, planes, blocks, stride=1):
                downsample = None
                if stride != 1 or self.inplanes != planes * block.expansion:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion),
                    )
                layers = []
                layers.append(block(self.inplanes, planes, stride, downsample))
                self.inplanes = planes * block.expansion
                for _ in range(1, blocks):
                    layers.append(block(self.inplanes, planes))

                return nn.Sequential(*layers)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)

                return x

        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
        self.checkExportImport(torchvision.models.resnet18().eval(), (torch.randn(1, 3, 224, 224),))

    @unittest.skip('torchvision models are not supported yet')  # FIXME
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
        self.checkExportImport(model, (x,))
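checkExportImport above drives the full round trip this commit tests: TorchScript → Retiarii graph IR → regenerated PyTorch source → an executable module whose outputs must match the original within 1e-4. A minimal standalone sketch of the same pipeline, assuming the NNI APIs as of this commit (the regenerated source defines a class named _model, which is what the test's exec call instantiates):

import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script

# Same workaround as setUp above.
import nni.retiarii.debug_configs
nni.retiarii.debug_configs.framework = 'pytorch'

model = MnistNet().eval()                  # any module the converter supports
script_module = torch.jit.script(model)    # TorchScript front end
model_ir = convert_to_graph(script_module, model)
print(model_to_pytorch_script(model_ir))   # regenerated source defining _model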