Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
5e6a7a57
Unverified
Commit
5e6a7a57
authored
Mar 29, 2021
by
msbaines
Committed by
GitHub
Mar 29, 2021
Browse files
[feat] multiprocess_pipe: add checkpoint support (#555)
parent
9a950651
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
7 deletions
+24
-7
fairscale/experimental/nn/multiprocess_pipe.py
fairscale/experimental/nn/multiprocess_pipe.py
+19
-5
stubs/torch/utils/checkpoint.pyi
stubs/torch/utils/checkpoint.pyi
+1
-0
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+4
-2
No files found.
fairscale/experimental/nn/multiprocess_pipe.py
View file @
5e6a7a57
...
...
@@ -11,6 +11,7 @@ import torch
from
torch
import
Tensor
import
torch.distributed.rpc
as
rpc
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint_sequential
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
...
...
@@ -69,6 +70,12 @@ def _rcat(tensors: List) -> Tensor:
return
torch
.
cat
([
t
.
local_value
()
for
t
in
tensors
])
def _rcheckpoint(rmodule: rpc.RRef, input_rref: rpc.RRef) -> TensorOrTensors:
    """Run the locally held partition on ``input_rref`` with checkpointing.

    The value behind ``rmodule`` is indexed and sliced like a sequence of
    layers: the first layer is called with the RRef itself, and every layer
    after it is passed to ``checkpoint_sequential`` as a single segment.

    Args:
        rmodule: RRef whose local value is the partition to execute.
        input_rref: RRef handed to the partition's first layer.

    Returns:
        Whatever ``checkpoint_sequential`` returns for the remaining layers.
    """
    layers = rmodule.local_value()
    fetched = layers[0](input_rref)  # calls _ToHere.forward
    remaining = layers[1:]
    return checkpoint_sequential(remaining, 1, fetched)
def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
    """Wrap each parameter of the locally held module in its own RRef.

    Args:
        module: RRef whose local value exposes ``parameters()``.

    Returns:
        One ``rpc.RRef`` per parameter, in the order ``parameters()`` yields them.
    """
    local_module = module.local_value()
    return list(map(rpc.RRef, local_module.parameters()))
...
...
@@ -159,8 +166,8 @@ class MultiProcessPipe(Module):
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
raise
ValueError
(
"number of chunks must be positive integer"
)
if
checkpoint
not
in
[
"never"
]:
raise
ValueError
(
"checkpoint is not
yet implemented
"
)
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
raise
ValueError
(
"checkpoint is not
one of 'always', 'except_last', or 'never'
"
)
if
deferred_batch_norm
:
raise
ValueError
(
"deferred_batch_norm is not yet implemented"
)
if
len
(
balance
)
!=
len
(
devices
):
...
...
@@ -181,6 +188,9 @@ class MultiProcessPipe(Module):
workers
.
append
(
worker
)
rmodule
.
append
(
rlayer
)
# The micro-batch index where the checkpointing stops.
self
.
checkpoint_stop
=
{
"always"
:
chunks
,
"except_last"
:
chunks
-
1
,
"never"
:
0
}[
checkpoint
]
self
.
chunks
=
chunks
self
.
checkpoint
=
checkpoint
self
.
module
=
module
...
...
@@ -189,10 +199,14 @@ class MultiProcessPipe(Module):
def forward(self, x: Tensor) -> rpc.RRef:  # type: ignore
    """Push ``x`` through the remote partitions one micro-batch at a time.

    ``x`` is split with ``x.chunk(self.chunks)``. A micro-batch whose index is
    below ``self.checkpoint_stop`` is sent through ``_rcheckpoint`` on the
    owner of each partition; every other micro-batch calls the partition's
    ``forward`` remotely.

    Args:
        x: Input tensor to split into micro-batches.

    Returns:
        An RRef, created on the owner of the first micro-batch's final output,
        to the result of ``_rcat`` over all micro-batch outputs.
    """
    outputs = []
    for i, chunk in enumerate(x.chunk(self.chunks)):
        output = rpc.RRef(chunk)
        # The choice depends only on the micro-batch index, so decide once
        # per chunk and walk the partitions a single time.
        use_checkpoint = i < self.checkpoint_stop
        for rlayer in self.rmodule:
            if use_checkpoint:
                output = rpc.remote(rlayer.owner(), _rcheckpoint, args=(rlayer, output))
            else:
                output = rlayer.remote().forward(output)
        outputs.append(output)
    return rpc.remote(outputs[0].owner(), _rcat, args=(outputs,))
...
...
stubs/torch/utils/checkpoint.pyi
View file @
5e6a7a57
...
...
@@ -7,3 +7,4 @@ from torch.nn.modules.module import Module
# Minimal type stubs for torch.utils.checkpoint; every body is elided with "...".
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
def checkpoint(function: Module, *args, **kwargs): ...
def check_backward_validity(inputs: Iterable[Any]): ...
# NOTE(review): the first parameter is annotated as a single Module — confirm
# this matches how callers pass the layers to be checkpointed.
def checkpoint_sequential(function: Module, segments: int, *args, **kwargs): ...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
5e6a7a57
...
...
@@ -130,13 +130,15 @@ def forward_chunks(devices):
@
rpc_test
(
world_size
=
2
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward_multi
(
devices
):
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"never"
,
"always"
,
"except_last"
])
def
forward_multi
(
devices
,
checkpoint
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
torch
.
random
.
manual_seed
(
3
)
torch
.
cuda
.
manual_seed_all
(
3
)
x
=
torch
.
randn
(
8
,
4
).
to
(
device
)
x
.
requires_grad
=
True
# TODO(msb) remove this limitation
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
])
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
]
,
checkpoint
=
checkpoint
)
if
BOUNCE_TENSORS
:
y
=
pipe
(
x
).
remote
().
cpu
().
to_here
()
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment