Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
62635f0f
Unverified
Commit
62635f0f
authored
Mar 28, 2021
by
msbaines
Committed by
GitHub
Mar 28, 2021
Browse files
[feat] multiprocess_pipe: add support for testing gpu-gpu rpc (#552)
parent
9a6ca9bd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
7 deletions
+21
-7
fairscale/experimental/nn/multiprocess_pipe.py
fairscale/experimental/nn/multiprocess_pipe.py
+6
-3
stubs/torch/distributed/rpc/__init__.pyi
stubs/torch/distributed/rpc/__init__.pyi
+1
-0
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+14
-4
No files found.
fairscale/experimental/nn/multiprocess_pipe.py
View file @
62635f0f
...
...
@@ -22,7 +22,10 @@ if TYPE_CHECKING:
else
:
Module
=
nn
.
Module
BOUNCE_TENSORS
=
True
if
torch
.
__version__
.
split
(
"+"
)[
0
].
split
(
"."
)[:
3
]
<=
[
"1"
,
"8"
,
"1"
]:
BOUNCE_TENSORS
=
True
else
:
BOUNCE_TENSORS
=
False
def
_verify_module
(
module
:
List
[
LayerSpec
])
->
None
:
...
...
@@ -53,7 +56,7 @@ class _ToHere(Module):
if
BOUNCE_TENSORS
:
return
x_rref
.
remote
().
cpu
().
to_here
().
to
(
self
.
device
)
else
:
return
x_rref
.
to_here
()
.
to
(
self
.
device
)
return
x_rref
.
to_here
()
def
_create_sequential
(
layer_spec
:
List
[
LayerSpec
],
device
:
str
)
->
Module
:
...
...
@@ -67,7 +70,7 @@ def _rcat(tensors: List) -> Tensor:
def
_parameter_rrefs
(
module
:
rpc
.
RRef
)
->
List
[
rpc
.
RRef
]:
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
to_her
e
().
parameters
()]
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
local_valu
e
().
parameters
()]
def
rloss
(
loss_func
:
Callable
,
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
...
...
stubs/torch/distributed/rpc/__init__.pyi
View file @
62635f0f
...
...
@@ -5,6 +5,7 @@ from torch.futures import Future
class RRef:
def __init__(self, t: Any) -> None: ...
def local_value(self) -> Any: ...
def owner(self) -> WorkerInfo: ...
def remote(self) -> Any: ...
def rpc_sync(self) -> Any: ...
...
...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
62635f0f
...
...
@@ -21,7 +21,10 @@ import torch.nn as nn
from
fairscale.experimental.nn.multiprocess_pipe
import
DistributedLoss
,
MultiProcessPipe
from
fairscale.utils.testing
import
torch_version
BOUNCE_TENSORS
=
True
if
torch_version
()
<=
(
1
,
8
,
1
):
BOUNCE_TENSORS
=
True
else
:
BOUNCE_TENSORS
=
False
CPU_DEVICES
=
[
"worker0/cpu"
,
"worker1/cpu"
]
GPU_DEVICES
=
[
"worker0/cuda:0"
,
"worker1/cuda:1"
]
...
...
@@ -46,6 +49,10 @@ def rpc_worker(rank, world_size, init_file, func, *args):
)
else
:
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
if
torch_version
()
>
(
1
,
8
,
1
):
for
i
in
range
(
world_size
):
if
i
!=
rank
:
options
.
set_device_map
(
"worker"
+
str
(
i
),
{
rank
:
i
})
rpc
.
init_rpc
(
"worker"
+
str
(
rank
),
rank
=
rank
,
...
...
@@ -124,9 +131,10 @@ def forward_chunks(devices):
@
rpc_test
(
world_size
=
2
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward_multi
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
torch
.
random
.
manual_seed
(
3
)
torch
.
cuda
.
manual_seed_all
(
3
)
x
=
torch
.
randn
(
8
,
4
)
x
=
torch
.
randn
(
8
,
4
)
.
to
(
device
)
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
])
if
BOUNCE_TENSORS
:
...
...
@@ -142,9 +150,10 @@ def forward_multi(devices):
@
rpc_test
(
world_size
=
2
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
backward
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
torch
.
random
.
manual_seed
(
3
)
criterion
=
DistributedLoss
(
torch
.
nn
.
MSELoss
)
x
=
torch
.
randn
(
8
,
4
)
x
=
torch
.
randn
(
8
,
4
)
.
to
(
device
)
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
])
with
dist_autograd
.
context
()
as
context_id
:
...
...
@@ -158,9 +167,10 @@ def backward(devices):
@
rpc_test
(
world_size
=
2
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
update
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
torch
.
random
.
manual_seed
(
3
)
criterion
=
DistributedLoss
(
torch
.
nn
.
MSELoss
)
x
=
torch
.
randn
(
8
,
4
)
x
=
torch
.
randn
(
8
,
4
)
.
to
(
device
)
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
])
params
=
pipe
.
parameter_rrefs
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment