Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
5e6a7a57
Unverified
Commit
5e6a7a57
authored
Mar 29, 2021
by
msbaines
Committed by
GitHub
Mar 29, 2021
Browse files
[feat] multiprocess_pipe: add checkpoint support (#555)
parent
9a950651
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
7 deletions
+24
-7
fairscale/experimental/nn/multiprocess_pipe.py
fairscale/experimental/nn/multiprocess_pipe.py
+19
-5
stubs/torch/utils/checkpoint.pyi
stubs/torch/utils/checkpoint.pyi
+1
-0
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+4
-2
No files found.
fairscale/experimental/nn/multiprocess_pipe.py
View file @
5e6a7a57
...
@@ -11,6 +11,7 @@ import torch
...
@@ -11,6 +11,7 @@ import torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.distributed.rpc
as
rpc
import
torch.distributed.rpc
as
rpc
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint_sequential
Tensors
=
Tuple
[
Tensor
,
...]
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
...
@@ -69,6 +70,12 @@ def _rcat(tensors: List) -> Tensor:
...
@@ -69,6 +70,12 @@ def _rcat(tensors: List) -> Tensor:
return
torch
.
cat
([
t
.
local_value
()
for
t
in
tensors
])
return
torch
.
cat
([
t
.
local_value
()
for
t
in
tensors
])
def
_rcheckpoint
(
rmodule
:
rpc
.
RRef
,
input_rref
:
rpc
.
RRef
)
->
TensorOrTensors
:
module
=
rmodule
.
local_value
()
input
=
module
[
0
](
input_rref
)
# calls _ToHere.forward
return
checkpoint_sequential
(
module
[
1
:],
1
,
input
)
def
_parameter_rrefs
(
module
:
rpc
.
RRef
)
->
List
[
rpc
.
RRef
]:
def
_parameter_rrefs
(
module
:
rpc
.
RRef
)
->
List
[
rpc
.
RRef
]:
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
local_value
().
parameters
()]
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
local_value
().
parameters
()]
...
@@ -159,8 +166,8 @@ class MultiProcessPipe(Module):
...
@@ -159,8 +166,8 @@ class MultiProcessPipe(Module):
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
raise
ValueError
(
"number of chunks must be positive integer"
)
raise
ValueError
(
"number of chunks must be positive integer"
)
if
checkpoint
not
in
[
"never"
]:
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
raise
ValueError
(
"checkpoint is not
yet implemented
"
)
raise
ValueError
(
"checkpoint is not
one of 'always', 'except_last', or 'never'
"
)
if
deferred_batch_norm
:
if
deferred_batch_norm
:
raise
ValueError
(
"deferred_batch_norm is not yet implemented"
)
raise
ValueError
(
"deferred_batch_norm is not yet implemented"
)
if
len
(
balance
)
!=
len
(
devices
):
if
len
(
balance
)
!=
len
(
devices
):
...
@@ -181,6 +188,9 @@ class MultiProcessPipe(Module):
...
@@ -181,6 +188,9 @@ class MultiProcessPipe(Module):
workers
.
append
(
worker
)
workers
.
append
(
worker
)
rmodule
.
append
(
rlayer
)
rmodule
.
append
(
rlayer
)
# The micro-batch index where the checkpointing stops.
self
.
checkpoint_stop
=
{
"always"
:
chunks
,
"except_last"
:
chunks
-
1
,
"never"
:
0
}[
checkpoint
]
self
.
chunks
=
chunks
self
.
chunks
=
chunks
self
.
checkpoint
=
checkpoint
self
.
checkpoint
=
checkpoint
self
.
module
=
module
self
.
module
=
module
...
@@ -189,8 +199,12 @@ class MultiProcessPipe(Module):
...
@@ -189,8 +199,12 @@ class MultiProcessPipe(Module):
def
forward
(
self
,
x
:
Tensor
)
->
rpc
.
RRef
:
# type: ignore
def
forward
(
self
,
x
:
Tensor
)
->
rpc
.
RRef
:
# type: ignore
outputs
=
[]
outputs
=
[]
for
chunk
in
x
.
chunk
(
self
.
chunks
):
for
i
,
chunk
in
enumerate
(
x
.
chunk
(
self
.
chunks
)
)
:
output
=
rpc
.
RRef
(
chunk
)
output
=
rpc
.
RRef
(
chunk
)
if
i
<
self
.
checkpoint_stop
:
for
rlayer
in
self
.
rmodule
:
output
=
rpc
.
remote
(
rlayer
.
owner
(),
_rcheckpoint
,
args
=
(
rlayer
,
output
))
else
:
for
rlayer
in
self
.
rmodule
:
for
rlayer
in
self
.
rmodule
:
output
=
rlayer
.
remote
().
forward
(
output
)
output
=
rlayer
.
remote
().
forward
(
output
)
outputs
.
append
(
output
)
outputs
.
append
(
output
)
...
...
stubs/torch/utils/checkpoint.pyi
View file @
5e6a7a57
...
@@ -7,3 +7,4 @@ from torch.nn.modules.module import Module
...
@@ -7,3 +7,4 @@ from torch.nn.modules.module import Module
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
def checkpoint(function: Module, *args, **kwargs): ...
def checkpoint(function: Module, *args, **kwargs): ...
def check_backward_validity(inputs: Iterable[Any]): ...
def check_backward_validity(inputs: Iterable[Any]): ...
def checkpoint_sequential(function: Module, segments: int, *args, **kwargs): ...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
5e6a7a57
...
@@ -130,13 +130,15 @@ def forward_chunks(devices):
...
@@ -130,13 +130,15 @@ def forward_chunks(devices):
@
rpc_test
(
world_size
=
2
)
@
rpc_test
(
world_size
=
2
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward_multi
(
devices
):
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"never"
,
"always"
,
"except_last"
])
def
forward_multi
(
devices
,
checkpoint
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
device
=
devices
[
0
].
split
(
"/"
)[
1
]
torch
.
random
.
manual_seed
(
3
)
torch
.
random
.
manual_seed
(
3
)
torch
.
cuda
.
manual_seed_all
(
3
)
torch
.
cuda
.
manual_seed_all
(
3
)
x
=
torch
.
randn
(
8
,
4
).
to
(
device
)
x
=
torch
.
randn
(
8
,
4
).
to
(
device
)
x
.
requires_grad
=
True
# TODO(msb) remove this limitation
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
])
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
]
,
checkpoint
=
checkpoint
)
if
BOUNCE_TENSORS
:
if
BOUNCE_TENSORS
:
y
=
pipe
(
x
).
remote
().
cpu
().
to_here
()
y
=
pipe
(
x
).
remote
().
cpu
().
to_here
()
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment