Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
204392e5
Unverified
Commit
204392e5
authored
Mar 31, 2021
by
msbaines
Committed by
GitHub
Mar 31, 2021
Browse files
[refactor] multiprocess_pipe: only support torch >= 1.9.0 (#561)
parent
34384e1b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
43 deletions
+15
-43
fairscale/experimental/nn/multiprocess_pipe.py
fairscale/experimental/nn/multiprocess_pipe.py
+6
-15
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+9
-28
No files found.
fairscale/experimental/nn/multiprocess_pipe.py
View file @
204392e5
...
...
@@ -23,11 +23,6 @@ if TYPE_CHECKING:
else
:
Module
=
nn
.
Module
if
torch
.
__version__
.
split
(
"+"
)[
0
].
split
(
"."
)[:
3
]
<=
[
"1"
,
"8"
,
"1"
]:
BOUNCE_TENSORS
=
True
else
:
BOUNCE_TENSORS
=
False
def
_verify_module
(
module
:
List
[
LayerSpec
])
->
None
:
if
not
isinstance
(
module
,
List
):
...
...
@@ -54,10 +49,7 @@ class _ToHere(Module):
self
.
device
=
device
def
forward
(
self
,
x_rref
:
rpc
.
RRef
)
->
Tensor
:
# type: ignore
if
BOUNCE_TENSORS
:
return
x_rref
.
remote
().
cpu
().
to_here
().
to
(
self
.
device
)
else
:
return
x_rref
.
to_here
()
return
x_rref
.
to_here
()
def
_create_sequential
(
layer_spec
:
List
[
LayerSpec
],
device
:
str
)
->
Module
:
...
...
@@ -80,18 +72,15 @@ def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
local_value
().
parameters
()]
def
rloss
(
loss_func
:
Callable
,
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
if
BOUNCE_TENSORS
:
return
loss_func
(
input_rref
.
remote
().
cpu
().
to_here
(),
target_rref
.
remote
().
cpu
().
to_here
())
else
:
return
loss_func
(
input_rref
.
to_here
(),
target_rref
.
to_here
())
def
_rloss
(
loss_func
:
Callable
,
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
return
loss_func
(
input_rref
.
to_here
(),
target_rref
.
to_here
())
def
DistributedLoss
(
loss
:
nn
.
Module
,
*
args
:
Tuple
,
**
kwargs
:
Dict
)
->
Callable
:
loss_func
=
loss
(
*
args
,
**
kwargs
)
def
dloss
(
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
return
rpc
.
remote
(
input_rref
.
owner
(),
rloss
,
args
=
(
loss_func
,
input_rref
,
target_rref
))
return
rpc
.
remote
(
input_rref
.
owner
(),
_
rloss
,
args
=
(
loss_func
,
input_rref
,
target_rref
))
return
dloss
...
...
@@ -164,6 +153,8 @@ class MultiProcessPipe(Module):
)
->
None
:
super
().
__init__
()
if
torch
.
__version__
.
split
(
"."
)[:
2
]
<
[
"1"
,
"9"
]:
raise
RuntimeError
(
"MultiProcessPipe requires torch >= 1.9.0"
)
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
raise
ValueError
(
"number of chunks must be positive integer"
)
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
...
...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
204392e5
...
...
@@ -21,11 +21,6 @@ import torch.nn as nn
from
fairscale.experimental.nn.multiprocess_pipe
import
DistributedLoss
,
MultiProcessPipe
from
fairscale.utils.testing
import
torch_version
if
torch_version
()
<=
(
1
,
8
,
1
):
BOUNCE_TENSORS
=
True
else
:
BOUNCE_TENSORS
=
False
CPU_DEVICES
=
[
"worker0/cpu"
,
"worker1/cpu"
]
GPU_DEVICES
=
[
"worker0/cuda:0"
,
"worker1/cuda:1"
]
if
torch
.
cuda
.
is_available
():
...
...
@@ -33,26 +28,13 @@ if torch.cuda.is_available():
else
:
DEVICES
=
[
CPU_DEVICES
]
pytestmark
=
pytest
.
mark
.
skipif
(
torch_version
()
<
(
1
,
8
,
0
),
reason
=
"requires torch version >= 1.
8
.0"
)
pytestmark
=
pytest
.
mark
.
skipif
(
torch_version
()
<
(
1
,
9
,
0
),
reason
=
"requires torch version >= 1.
9
.0"
)
def
rpc_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
if
torch_version
()
==
(
1
,
8
,
0
):
if
torch
.
cuda
.
is_available
():
# Workaround for https://github.com/pytorch/pytorch/issues/53844
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
,
_transports
=
[
"ibv"
,
"uv"
])
else
:
# Workaround for https://github.com/pytorch/pytorch/issues/54266
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
,
_channels
=
[
"mpt_uv"
,
"basic"
,
"cuda_ipc"
,
"cuda_gdr"
,
"cuda_xth"
,
"cuda_basic"
],
)
else
:
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
if
torch_version
()
>
(
1
,
8
,
1
):
for
i
in
range
(
world_size
):
if
i
!=
rank
:
options
.
set_device_map
(
"worker"
+
str
(
i
),
{
rank
:
i
})
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
for
i
in
range
(
world_size
):
options
.
set_device_map
(
"worker"
+
str
(
i
),
{
rank
:
i
})
rpc
.
init_rpc
(
"worker"
+
str
(
rank
),
rank
=
rank
,
...
...
@@ -109,8 +91,9 @@ def parameter_rrefs(devices):
@
rpc_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
yh
=
torch
.
tensor
([
1.0
,
0.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
])
.
to
(
device
)
model
=
[(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
chunks
=
1
,
devices
=
devices
[:
1
])
y
=
pipe
(
x
).
to_here
().
cpu
()
...
...
@@ -120,8 +103,9 @@ def forward(devices):
@
rpc_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward_chunks
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
yh
=
torch
.
tensor
([
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
0.0
,
4.0
,
0.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
,
2.0
,
-
2.0
,
3.0
,
-
3.0
,
4.0
,
-
4.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
,
2.0
,
-
2.0
,
3.0
,
-
3.0
,
4.0
,
-
4.0
])
.
to
(
device
)
model
=
[(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
chunks
=
4
,
devices
=
devices
[:
1
])
y
=
pipe
(
x
).
to_here
().
cpu
()
...
...
@@ -139,10 +123,7 @@ def forward_multi(devices, checkpoint):
x
.
requires_grad
=
True
# TODO(msb) remove this limitation
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
],
checkpoint
=
checkpoint
)
if
BOUNCE_TENSORS
:
y
=
pipe
(
x
).
remote
().
cpu
().
to_here
()
else
:
y
=
pipe
(
x
).
to_here
()
y
=
pipe
(
x
).
to_here
()
expected_sum
=
torch
.
tensor
(
5.0615
)
assert
y
.
shape
==
torch
.
Size
([
8
,
4
])
assert
y
.
requires_grad
is
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment