Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
204392e5
Unverified
Commit
204392e5
authored
Mar 31, 2021
by
msbaines
Committed by
GitHub
Mar 31, 2021
Browse files
[refactor] multiprocess_pipe: only support torch >= 1.9.0 (#561)
parent
34384e1b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
43 deletions
+15
-43
fairscale/experimental/nn/multiprocess_pipe.py
fairscale/experimental/nn/multiprocess_pipe.py
+6
-15
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+9
-28
No files found.
fairscale/experimental/nn/multiprocess_pipe.py
View file @
204392e5
...
@@ -23,11 +23,6 @@ if TYPE_CHECKING:
...
@@ -23,11 +23,6 @@ if TYPE_CHECKING:
else
:
else
:
Module
=
nn
.
Module
Module
=
nn
.
Module
if
torch
.
__version__
.
split
(
"+"
)[
0
].
split
(
"."
)[:
3
]
<=
[
"1"
,
"8"
,
"1"
]:
BOUNCE_TENSORS
=
True
else
:
BOUNCE_TENSORS
=
False
def
_verify_module
(
module
:
List
[
LayerSpec
])
->
None
:
def
_verify_module
(
module
:
List
[
LayerSpec
])
->
None
:
if
not
isinstance
(
module
,
List
):
if
not
isinstance
(
module
,
List
):
...
@@ -54,10 +49,7 @@ class _ToHere(Module):
...
@@ -54,10 +49,7 @@ class _ToHere(Module):
self
.
device
=
device
self
.
device
=
device
def
forward
(
self
,
x_rref
:
rpc
.
RRef
)
->
Tensor
:
# type: ignore
def
forward
(
self
,
x_rref
:
rpc
.
RRef
)
->
Tensor
:
# type: ignore
if
BOUNCE_TENSORS
:
return
x_rref
.
to_here
()
return
x_rref
.
remote
().
cpu
().
to_here
().
to
(
self
.
device
)
else
:
return
x_rref
.
to_here
()
def
_create_sequential
(
layer_spec
:
List
[
LayerSpec
],
device
:
str
)
->
Module
:
def
_create_sequential
(
layer_spec
:
List
[
LayerSpec
],
device
:
str
)
->
Module
:
...
@@ -80,18 +72,15 @@ def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
...
@@ -80,18 +72,15 @@ def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
local_value
().
parameters
()]
return
[
rpc
.
RRef
(
p
)
for
p
in
module
.
local_value
().
parameters
()]
def
rloss
(
loss_func
:
Callable
,
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
def
_rloss
(
loss_func
:
Callable
,
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
if
BOUNCE_TENSORS
:
return
loss_func
(
input_rref
.
to_here
(),
target_rref
.
to_here
())
return
loss_func
(
input_rref
.
remote
().
cpu
().
to_here
(),
target_rref
.
remote
().
cpu
().
to_here
())
else
:
return
loss_func
(
input_rref
.
to_here
(),
target_rref
.
to_here
())
def
DistributedLoss
(
loss
:
nn
.
Module
,
*
args
:
Tuple
,
**
kwargs
:
Dict
)
->
Callable
:
def
DistributedLoss
(
loss
:
nn
.
Module
,
*
args
:
Tuple
,
**
kwargs
:
Dict
)
->
Callable
:
loss_func
=
loss
(
*
args
,
**
kwargs
)
loss_func
=
loss
(
*
args
,
**
kwargs
)
def
dloss
(
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
def
dloss
(
input_rref
:
rpc
.
RRef
,
target_rref
:
rpc
.
RRef
)
->
rpc
.
RRef
:
return
rpc
.
remote
(
input_rref
.
owner
(),
rloss
,
args
=
(
loss_func
,
input_rref
,
target_rref
))
return
rpc
.
remote
(
input_rref
.
owner
(),
_
rloss
,
args
=
(
loss_func
,
input_rref
,
target_rref
))
return
dloss
return
dloss
...
@@ -164,6 +153,8 @@ class MultiProcessPipe(Module):
...
@@ -164,6 +153,8 @@ class MultiProcessPipe(Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
torch
.
__version__
.
split
(
"."
)[:
2
]
<
[
"1"
,
"9"
]:
raise
RuntimeError
(
"MultiProcessPipe requires torch >= 1.9.0"
)
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
if
type
(
chunks
)
is
not
int
or
chunks
<=
0
:
raise
ValueError
(
"number of chunks must be positive integer"
)
raise
ValueError
(
"number of chunks must be positive integer"
)
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
...
...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
204392e5
...
@@ -21,11 +21,6 @@ import torch.nn as nn
...
@@ -21,11 +21,6 @@ import torch.nn as nn
from
fairscale.experimental.nn.multiprocess_pipe
import
DistributedLoss
,
MultiProcessPipe
from
fairscale.experimental.nn.multiprocess_pipe
import
DistributedLoss
,
MultiProcessPipe
from
fairscale.utils.testing
import
torch_version
from
fairscale.utils.testing
import
torch_version
if
torch_version
()
<=
(
1
,
8
,
1
):
BOUNCE_TENSORS
=
True
else
:
BOUNCE_TENSORS
=
False
CPU_DEVICES
=
[
"worker0/cpu"
,
"worker1/cpu"
]
CPU_DEVICES
=
[
"worker0/cpu"
,
"worker1/cpu"
]
GPU_DEVICES
=
[
"worker0/cuda:0"
,
"worker1/cuda:1"
]
GPU_DEVICES
=
[
"worker0/cuda:0"
,
"worker1/cuda:1"
]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -33,26 +28,13 @@ if torch.cuda.is_available():
...
@@ -33,26 +28,13 @@ if torch.cuda.is_available():
else
:
else
:
DEVICES
=
[
CPU_DEVICES
]
DEVICES
=
[
CPU_DEVICES
]
pytestmark
=
pytest
.
mark
.
skipif
(
torch_version
()
<
(
1
,
8
,
0
),
reason
=
"requires torch version >= 1.
8
.0"
)
pytestmark
=
pytest
.
mark
.
skipif
(
torch_version
()
<
(
1
,
9
,
0
),
reason
=
"requires torch version >= 1.
9
.0"
)
def
rpc_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
def
rpc_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
if
torch_version
()
==
(
1
,
8
,
0
):
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
if
torch
.
cuda
.
is_available
():
for
i
in
range
(
world_size
):
# Workaround for https://github.com/pytorch/pytorch/issues/53844
options
.
set_device_map
(
"worker"
+
str
(
i
),
{
rank
:
i
})
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
,
_transports
=
[
"ibv"
,
"uv"
])
else
:
# Workaround for https://github.com/pytorch/pytorch/issues/54266
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
,
_channels
=
[
"mpt_uv"
,
"basic"
,
"cuda_ipc"
,
"cuda_gdr"
,
"cuda_xth"
,
"cuda_basic"
],
)
else
:
options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://"
+
init_file
)
if
torch_version
()
>
(
1
,
8
,
1
):
for
i
in
range
(
world_size
):
if
i
!=
rank
:
options
.
set_device_map
(
"worker"
+
str
(
i
),
{
rank
:
i
})
rpc
.
init_rpc
(
rpc
.
init_rpc
(
"worker"
+
str
(
rank
),
"worker"
+
str
(
rank
),
rank
=
rank
,
rank
=
rank
,
...
@@ -109,8 +91,9 @@ def parameter_rrefs(devices):
...
@@ -109,8 +91,9 @@ def parameter_rrefs(devices):
@
rpc_test
(
world_size
=
1
)
@
rpc_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward
(
devices
):
def
forward
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
yh
=
torch
.
tensor
([
1.0
,
0.0
])
yh
=
torch
.
tensor
([
1.0
,
0.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
])
.
to
(
device
)
model
=
[(
"relu"
,
nn
.
ReLU
,
(),
{})]
model
=
[(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
chunks
=
1
,
devices
=
devices
[:
1
])
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
chunks
=
1
,
devices
=
devices
[:
1
])
y
=
pipe
(
x
).
to_here
().
cpu
()
y
=
pipe
(
x
).
to_here
().
cpu
()
...
@@ -120,8 +103,9 @@ def forward(devices):
...
@@ -120,8 +103,9 @@ def forward(devices):
@
rpc_test
(
world_size
=
1
)
@
rpc_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
@
pytest
.
mark
.
parametrize
(
"devices"
,
DEVICES
)
def
forward_chunks
(
devices
):
def
forward_chunks
(
devices
):
device
=
devices
[
0
].
split
(
"/"
)[
1
]
yh
=
torch
.
tensor
([
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
0.0
,
4.0
,
0.0
])
yh
=
torch
.
tensor
([
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
0.0
,
4.0
,
0.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
,
2.0
,
-
2.0
,
3.0
,
-
3.0
,
4.0
,
-
4.0
])
x
=
torch
.
tensor
([
1.0
,
-
1.0
,
2.0
,
-
2.0
,
3.0
,
-
3.0
,
4.0
,
-
4.0
])
.
to
(
device
)
model
=
[(
"relu"
,
nn
.
ReLU
,
(),
{})]
model
=
[(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
chunks
=
4
,
devices
=
devices
[:
1
])
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
chunks
=
4
,
devices
=
devices
[:
1
])
y
=
pipe
(
x
).
to_here
().
cpu
()
y
=
pipe
(
x
).
to_here
().
cpu
()
...
@@ -139,10 +123,7 @@ def forward_multi(devices, checkpoint):
...
@@ -139,10 +123,7 @@ def forward_multi(devices, checkpoint):
x
.
requires_grad
=
True
# TODO(msb) remove this limitation
x
.
requires_grad
=
True
# TODO(msb) remove this limitation
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
model
=
[(
"linear1"
,
nn
.
Linear
,
(
4
,
4
),
{}),
(
"relu"
,
nn
.
ReLU
,
(),
{})]
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
],
checkpoint
=
checkpoint
)
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
],
chunks
=
4
,
devices
=
devices
[:
2
],
checkpoint
=
checkpoint
)
if
BOUNCE_TENSORS
:
y
=
pipe
(
x
).
to_here
()
y
=
pipe
(
x
).
remote
().
cpu
().
to_here
()
else
:
y
=
pipe
(
x
).
to_here
()
expected_sum
=
torch
.
tensor
(
5.0615
)
expected_sum
=
torch
.
tensor
(
5.0615
)
assert
y
.
shape
==
torch
.
Size
([
8
,
4
])
assert
y
.
shape
==
torch
.
Size
([
8
,
4
])
assert
y
.
requires_grad
is
True
assert
y
.
requires_grad
is
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment