Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
2478a9ad
Unverified
Commit
2478a9ad
authored
Feb 25, 2021
by
Min Xu
Committed by
GitHub
Feb 25, 2021
Browse files
[test] checkpoint: multiple input and output model test (#425)
parent
3b0717eb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
0 deletions
+57
-0
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations.py
+57
-0
No files found.
tests/nn/misc/test_checkpoint_activations.py
View file @
2478a9ad
...
...
@@ -195,3 +195,60 @@ def test_offload_memory():
# Use print to collect all debugging info.
print
(
base
,
cpt
,
offload
)
assert
0
class MultiinMultioutModel(nn.Module):
    """Model used to check different inputs and outputs

    Two independent conv branches consume inputs with different channel
    counts (1 and 3). ``checkpoint_config`` is a 2-bit mask selecting which
    branch gets wrapped with activation checkpointing: bit 0 -> conv1,
    bit 1 -> conv2.
    """

    def __init__(self, multiout=False, checkpoint_config=0):
        super().__init__()
        # Seed the RNG so the conv weights are identical across instances.
        torch.manual_seed(0)
        self.multiout = multiout
        self.conv1 = nn.Sequential(nn.Conv2d(1, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))
        self.conv2 = nn.Sequential(nn.Conv2d(3, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))
        assert 0 <= checkpoint_config <= 3
        # Wrap whichever branches the bitmask selects.
        for bit, attr in ((1, "conv1"), (1 << 1, "conv2")):
            if checkpoint_config & bit:
                setattr(self, attr, checkpoint_wrapper(getattr(self, attr)))

    def forward(self, x1, x2=None):
        """Run both branches; return a tuple when ``multiout`` else their sum."""
        out1, out2 = self.conv1(x1), self.conv2(x2)
        return (out1, out2) if self.multiout else out1 + out2
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("multiout", [True, False])
@pytest.mark.parametrize("checkpoint_config", [1, 2, 3])
def test_multiin_multiout(device, multiout, checkpoint_config):
    """Verify checkpointing does not change loss/grad-norm for a model
    with multiple inputs and (optionally) multiple outputs."""
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    def train(model, in1, in2):
        # One forward/backward pass; return the scalar loss and the
        # global gradient norm over all parameters.
        out = model(in1, x2=in2)
        if isinstance(out, tuple):
            out = torch.cat(out)
        loss = out.sum()
        loss.backward()
        per_param = [torch.norm(p.grad.detach()) for p in model.parameters()]
        gnorm = torch.norm(torch.stack(per_param))
        return {"loss": loss.item(), "gnorm": gnorm.item()}

    in1 = torch.rand(4, 1, 32, 32).requires_grad_(True)
    in2 = torch.rand(4, 3, 32, 32).requires_grad_(True)

    # Baseline: no checkpointing (config 0).
    baseline = MultiinMultioutModel(multiout, 0).to(device)
    no_cpt = train(baseline, in1.to(device), in2.to(device))

    # Same (seeded) weights with the requested checkpoint config.
    checkpointed = MultiinMultioutModel(multiout, checkpoint_config).to(device)
    cpt = train(checkpointed, in1.to(device), in2.to(device))

    # Checkpointing must be numerically transparent: exact equality expected.
    for key in ("loss", "gnorm"):
        if no_cpt[key] != cpt[key]:
            # Use print to collect all debugging info.
            print(no_cpt, cpt)
            assert 0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment