OpenDAS / fairscale · Commits

Commit 204392e5 (unverified)
Authored Mar 31, 2021 by msbaines; committed by GitHub on Mar 31, 2021
Parent: 34384e1b

[refactor] multiprocess_pipe: only support torch >= 1.9.0 (#561)
Showing 2 changed files with 15 additions and 43 deletions:

fairscale/experimental/nn/multiprocess_pipe.py (+6, -15)
tests/experimental/nn/test_multiprocess_pipe.py (+9, -28)
fairscale/experimental/nn/multiprocess_pipe.py
@@ -23,11 +23,6 @@ if TYPE_CHECKING:
 else:
     Module = nn.Module
 
-if torch.__version__.split("+")[0].split(".")[:3] <= ["1", "8", "1"]:
-    BOUNCE_TENSORS = True
-else:
-    BOUNCE_TENSORS = False
-
 
 def _verify_module(module: List[LayerSpec]) -> None:
     if not isinstance(module, List):
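The removed gate parsed `torch.__version__` by hand: it split off the local build suffix (e.g. `+cu111`), then compared the first three dot-separated components as strings. A minimal illustration of that parsing, using a hypothetical version string rather than a live torch install:

```python
# Illustration of the removed BOUNCE_TENSORS gate; the version string here is
# hypothetical, standing in for torch.__version__.
version = "1.8.1+cu111"
parts = version.split("+")[0].split(".")[:3]  # drop "+cu111", keep major/minor/patch
print(parts)                                  # ['1', '8', '1']
print(parts <= ["1", "8", "1"])               # True -> tensors were bounced through CPU
```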
@@ -54,9 +49,6 @@ class _ToHere(Module):
         self.device = device
 
     def forward(self, x_rref: rpc.RRef) -> Tensor:  # type: ignore
-        if BOUNCE_TENSORS:
-            return x_rref.remote().cpu().to_here().to(self.device)
-        else:
-            return x_rref.to_here()
+        return x_rref.to_here()
@@ -80,10 +72,7 @@ def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
     return [rpc.RRef(p) for p in module.local_value().parameters()]
 
 
-def rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
-    if BOUNCE_TENSORS:
-        return loss_func(input_rref.remote().cpu().to_here(), target_rref.remote().cpu().to_here())
-    else:
-        return loss_func(input_rref.to_here(), target_rref.to_here())
+def _rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
+    return loss_func(input_rref.to_here(), target_rref.to_here())
@@ -91,7 +80,7 @@ def DistributedLoss(loss: nn.Module, *args: Tuple, **kwargs: Dict) -> Callable:
     loss_func = loss(*args, **kwargs)
 
     def dloss(input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
-        return rpc.remote(input_rref.owner(), rloss, args=(loss_func, input_rref, target_rref))
+        return rpc.remote(input_rref.owner(), _rloss, args=(loss_func, input_rref, target_rref))
 
     return dloss
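`DistributedLoss` instantiates the loss module locally and returns a `dloss` closure that evaluates it on whichever worker owns the input RRef, returning an RRef to the result. A usage sketch, with illustrative device names, shapes, and loss choice, and assuming `rpc.init_rpc` has already been called on each participating process:

```python
# Hypothetical usage sketch; assumes rpc.init_rpc() has already run on every worker.
import torch
import torch.nn as nn
from torch.distributed import rpc

from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe

model = [("linear", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=["worker0/cpu", "worker1/cpu"])
criterion = DistributedLoss(nn.MSELoss)

x = torch.randn(8, 4)
target_rref = rpc.RRef(torch.randn(8, 4))    # wrap the target so it can be passed by reference
loss_rref = criterion(pipe(x), target_rref)  # loss runs on the worker that owns the pipe output
loss = loss_rref.to_here()
```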
@@ -164,6 +153,8 @@ class MultiProcessPipe(Module):
     ) -> None:
         super().__init__()
 
+        if torch.__version__.split(".")[:2] < ["1", "9"]:
+            raise RuntimeError("MultiProcessPipe requires torch >= 1.9.0")
         if type(chunks) is not int or chunks <= 0:
             raise ValueError("number of chunks must be positive integer")
         if checkpoint not in ["always", "except_last", "never"]:
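One caveat worth noting: this new guard compares version components as strings, which is lexicographic rather than numeric, so a future torch 1.10 would also trip it. The test file's `torch_version()` helper compares integer tuples instead, which avoids that pitfall:

```python
# Lexicographic string comparison vs numeric tuple comparison (runnable as-is):
print("1.10.0".split(".")[:2] < ["1", "9"])  # True -- "10" sorts before "9" as a string
print((1, 10, 0) < (1, 9, 0))                # False -- integers compare numerically
```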
tests/experimental/nn/test_multiprocess_pipe.py
@@ -21,11 +21,6 @@ import torch.nn as nn
 from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
 from fairscale.utils.testing import torch_version
 
-if torch_version() <= (1, 8, 1):
-    BOUNCE_TENSORS = True
-else:
-    BOUNCE_TENSORS = False
-
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
 if torch.cuda.is_available():
@@ -33,25 +28,12 @@ if torch.cuda.is_available():
 else:
     DEVICES = [CPU_DEVICES]
 
-pytestmark = pytest.mark.skipif(torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0")
+pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0")
 
 
 def rpc_worker(rank, world_size, init_file, func, *args):
-    if torch_version() == (1, 8, 0):
-        if torch.cuda.is_available():
-            # Workaround for https://github.com/pytorch/pytorch/issues/53844
-            options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file, _transports=["ibv", "uv"])
-        else:
-            # Workaround for https://github.com/pytorch/pytorch/issues/54266
-            options = rpc.TensorPipeRpcBackendOptions(
-                init_method="file://" + init_file,
-                _channels=["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"],
-            )
-    else:
-        options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
-    if torch_version() > (1, 8, 1):
-        for i in range(world_size):
-            if i != rank:
-                options.set_device_map("worker" + str(i), {rank: i})
+    options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
+    for i in range(world_size):
+        if i != rank:
+            options.set_device_map("worker" + str(i), {rank: i})
     rpc.init_rpc(
         "worker" + str(rank),
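With torch >= 1.9 assumed everywhere, the worker setup collapses to a single path: plain TensorPipe options plus an unconditional device map from this worker's device index (its rank) to each peer's. The device map is what lets `to_here()` deliver tensors straight onto the mapped device, which is why the old CPU bounce is no longer needed. A minimal sketch of the same setup (worker names, the init file path, and the rank-to-device mapping are illustrative, and a second process would have to run the mirror image of this):

```python
# Minimal sketch of the torch >= 1.9 RPC setup the test now relies on.
import torch.distributed.rpc as rpc

rank, world_size = 0, 2                       # this process plays "worker0"
options = rpc.TensorPipeRpcBackendOptions(init_method="file:///tmp/rpc_init")
for i in range(world_size):
    if i != rank:
        # map "my device <rank>" to "peer i's device <i>", so tensors sent to
        # worker i land directly on its device without a CPU round-trip
        options.set_device_map("worker" + str(i), {rank: i})
rpc.init_rpc("worker" + str(rank), rank=rank, world_size=world_size,
             rpc_backend_options=options)
```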
@@ -109,8 +91,9 @@ def parameter_rrefs(devices):
 @rpc_test(world_size=1)
 @pytest.mark.parametrize("devices", DEVICES)
 def forward(devices):
+    device = devices[0].split("/")[1]
     yh = torch.tensor([1.0, 0.0])
-    x = torch.tensor([1.0, -1.0])
+    x = torch.tensor([1.0, -1.0]).to(device)
     model = [("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])
     y = pipe(x).to_here().cpu()
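Since inputs are no longer bounced through CPU automatically, the tests now place the input tensor on the first pipeline device themselves. Each entry in DEVICES is a "worker/device" string, and the device half is what `torch.Tensor.to` expects:

```python
# How the tests derive a torch device from a "worker/device" spec (runnable as-is):
spec = "worker0/cuda:0"       # illustrative entry from GPU_DEVICES
device = spec.split("/")[1]
print(device)                 # cuda:0
```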
@@ -120,8 +103,9 @@ def forward(devices):
 @rpc_test(world_size=1)
 @pytest.mark.parametrize("devices", DEVICES)
 def forward_chunks(devices):
+    device = devices[0].split("/")[1]
     yh = torch.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
-    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0])
+    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0]).to(device)
     model = [("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1], chunks=4, devices=devices[:1])
     y = pipe(x).to_here().cpu()
@@ -139,9 +123,6 @@ def forward_multi(devices, checkpoint):
     x.requires_grad = True  # TODO(msb) remove this limitation
     model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2], checkpoint=checkpoint)
-    if BOUNCE_TENSORS:
-        y = pipe(x).remote().cpu().to_here()
-    else:
-        y = pipe(x).to_here()
+    y = pipe(x).to_here()
     expected_sum = torch.tensor(5.0615)
     assert y.shape == torch.Size([8, 4])