Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
5e6a7a57
Unverified
Commit
5e6a7a57
authored
Mar 29, 2021
by
msbaines
Committed by
GitHub
Mar 29, 2021
Browse files
[feat] multiprocess_pipe: add checkpoint support (#555)
parent
9a950651
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
7 deletions
+24
-7
fairscale/experimental/nn/multiprocess_pipe.py
fairscale/experimental/nn/multiprocess_pipe.py
+19
-5
stubs/torch/utils/checkpoint.pyi
stubs/torch/utils/checkpoint.pyi
+1
-0
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+4
-2
No files found.
fairscale/experimental/nn/multiprocess_pipe.py
View file @
5e6a7a57
...
...
@@ -11,6 +11,7 @@ import torch
from
torch
import
Tensor
import
torch.distributed.rpc
as
rpc
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint_sequential
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
...
...
@@ -69,6 +70,12 @@ def _rcat(tensors: List) -> Tensor:
return
torch
.
cat
([
t
.
local_value
()
for
t
in
tensors
])
def _rcheckpoint(rmodule: rpc.RRef, input_rref: rpc.RRef) -> TensorOrTensors:
    """Run the locally held stage on ``input_rref`` under checkpointing.

    Args:
        rmodule: RRef whose ``local_value()`` is an indexable, sliceable
            container of callables (presumably an ``nn.Sequential`` --
            confirm against the caller).
        input_rref: RRef handed to the first element of that container.

    Returns:
        The result of ``checkpoint_sequential`` applied, as a single
        segment, to every element after the first.
    """
    stage = rmodule.local_value()
    fetch_input = stage[0]
    local_input = fetch_input(input_rref)  # calls _ToHere.forward
    remaining_layers = stage[1:]
    return checkpoint_sequential(remaining_layers, 1, local_input)
def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
    """Wrap every parameter of the locally held module in its own RRef.

    Args:
        module: RRef whose ``local_value()`` exposes a ``parameters()``
            iterable.

    Returns:
        A list with one ``rpc.RRef`` per parameter, in iteration order.
    """
    local_module = module.local_value()
    wrapped = []
    for param in local_module.parameters():
        wrapped.append(rpc.RRef(param))
    return wrapped
...
...
@@ -159,8 +166,8 @@ class MultiProcessPipe(Module):
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
raise
ValueError
(
"number of chunks must be positive integer"
)
if
checkpoint
not
in
[
"never"
]:
raise
ValueError
(
"checkpoint is not
yet implemented
"
)
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
raise
ValueError
(
"checkpoint is not
one of 'always', 'except_last', or 'never'
"
)
if
deferred_batch_norm
:
raise
ValueError
(
"deferred_batch_norm is not yet implemented"
)
if
len
(
balance
)
!=
len
(
devices
):
...
...
@@ -181,6 +188,9 @@ class MultiProcessPipe(Module):
workers
.
append
(
worker
)
rmodule
.
append
(
rlayer
)
# The micro-batch index where the checkpointing stops.
self
.
checkpoint_stop
=
{
"always"
:
chunks
,
"except_last"
:
chunks
-
1
,
"never"
:
0
}[
checkpoint
]
self
.
chunks
=
chunks
self
.
checkpoint
=
checkpoint
self
.
module
=
module
...
...
@@ -189,8 +199,12 @@ class MultiProcessPipe(Module):
def
forward
(
self
,
x
:
Tensor
)
->
rpc
.
RRef
:
# type: ignore
outputs
=
[]
for
chunk
in
x
.
chunk
(
self
.
chunks
):
for
i
,
chunk
in
enumerate
(
x
.
chunk
(
self
.
chunks
)
)
:
output
=
rpc
.
RRef
(
chunk
)
if
i
<
self
.
checkpoint_stop
:
for
rlayer
in
self
.
rmodule
:
output
=
rpc
.
remote
(
rlayer
.
owner
(),
_rcheckpoint
,
args
=
(
rlayer
,
output
))
else
:
for
rlayer
in
self
.
rmodule
:
output
=
rlayer
.
remote
().
forward
(
output
)
outputs
.
append
(
output
)
...
...
stubs/torch/utils/checkpoint.pyi
View file @
5e6a7a57
...
...
@@ -7,3 +7,4 @@ from torch.nn.modules.module import Module
# Type stubs for torch.utils.checkpoint: signatures only, bodies elided with "...".
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
def checkpoint(function: Module, *args, **kwargs): ...
def check_backward_validity(inputs: Iterable[Any]): ...
# NOTE(review): upstream torch appears to name the first parameter `functions`
# and take the input positionally after `segments` -- confirm the stub matches.
def checkpoint_sequential(function: Module, segments: int, *args, **kwargs): ...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
5e6a7a57
...
...
@@ -130,13 +130,15 @@ def forward_chunks(devices):
@
rpc_test
(
world_size
=
2
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward_multi
(
devices
):
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"never"
,
"always"
,
"except_last"
])
def
forward_multi
(
devices
,
checkpoint
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
torch
.
random
.
manual_seed
(
3
)
torch
.
cuda
.
manual_seed_all
(
3
)
x
=
torch
.
randn
(
8
,
4
).
to
(
device
)
x
.
requires_grad
=
True
# TODO(msb) remove this limitation
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
])
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
]
,
checkpoint
=
checkpoint
)
if
BOUNCE_TENSORS
:
y
=
pipe
(
x
).
remote
().
cpu
().
to_here
()
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment