OpenDAS / MMCV · Commits · 144e7567

Commit 144e7567, authored Nov 18, 2020 by dreamerlin (parent 86d9f468)

    use pytest.mark.parametrize
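For context, the commit replaces the previous pattern, an OrderedDict of parameter lists expanded with itertools.product inside each test body, with pytest's parametrize decorator, which generates one independent test per parameter combination. A minimal, runnable sketch of the idea (the test name and assertion here are illustrative, not from the diff):

    import pytest

    # Each tuple becomes a separate test with its own ID, so a failure
    # pinpoints the exact (kernel_size, stride) combination instead of
    # aborting a hand-rolled loop over all remaining cases.
    @pytest.mark.parametrize('kernel_size,stride', [(3, 1), (5, 2)])
    def test_example(kernel_size, stride):
        assert kernel_size > stride  # placeholder assertion for the sketch

This also lets individual combinations be selected, deselected, or rerun on their own.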
Showing 1 changed file with 203 additions and 195 deletions.

tests/test_cnn/test_wrappers.py (+203, -195), view file @ 144e7567:
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
                             MaxPool2d, MaxPool3d)


@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """
    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv2d
    """
    # train mode
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
                            stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    torch.manual_seed(0)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
                            kernel_size, stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    torch.manual_seed(0)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                     padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    wrapper = MaxPool2d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    ref = nn.MaxPool2d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    assert torch.equal(wrapper(x_normal), ref_out)


@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
                     stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    wrapper = MaxPool3d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    ref = nn.MaxPool3d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    assert torch.equal(wrapper(x_normal), ref_out)


@patch('torch.__version__', '1.1')
@pytest.mark.parametrize('in_w,in_h,in_feature,out_feature',
                         [(10, 10, 1, 1), (20, 20, 3, 3)])
def test_linear(in_w, in_h, in_feature, out_feature):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_feature, requires_grad=True)
    torch.manual_seed(0)
    wrapper = Linear(in_feature, out_feature)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_feature)
    torch.manual_seed(0)
    ref = nn.Linear(in_feature, out_feature)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_feature)
...
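A note on the @patch('torch.__version__', '1.1') decorator repeated on each test above: unittest.mock.patch temporarily replaces the reported torch version for the duration of the test, presumably to force the wrappers' fallback path for empty (zero-batch) inputs that older PyTorch could not handle natively. A minimal sketch of the mechanism (illustrative, not from the diff):

    from unittest.mock import patch

    import torch

    # patch() swaps the attribute inside the context and restores it on exit;
    # used as a decorator, it does the same for the decorated test function.
    with patch('torch.__version__', '1.1'):
        assert torch.__version__ == '1.1'

With the tests parametrized, individual combinations can also be listed and run selectively, e.g. pytest tests/test_cnn/test_wrappers.py -k test_linear -v.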