OpenDAS / fairscale / Commits / eaee5976

Unverified commit eaee5976, authored Jan 29, 2021 by msbaines, committed by GitHub on Jan 29, 2021

[refactor] make AsyncPipe its own class (#341)

parent 51625eda

Showing 17 changed files with 200 additions and 286 deletions (+200, −286)
benchmarks/experimental_ampnet.py                               +1   −2
benchmarks/pipe.py                                              +0   −1
experimental/nn/ampnet_pipe/ampnet.py                           +0   −1
experimental/nn/ampnet_pipe/pipe.py                             +2   −4
experimental/tests/nn/ampnet_pipe_process/test_ampnet_pipe.py   +2   −17
fairscale/nn/pipe/__init__.py                                   +1   −0
fairscale/nn/pipe/async_pipe.py                                 +12  −0
fairscale/nn/pipe/multiprocess_pipe.py                          +0   −3
fairscale/nn/pipe/rpc.py                                        +2   −2
fairscale/nn/pipe/types.py                                      +0   −1
tests/nn/model_parallel/test_layers.py                          +1   −4
tests/nn/pipe_process/skip/test_gpipe.py                        +16  −22
tests/nn/pipe_process/skip/test_leak.py                         +6   −8
tests/nn/pipe_process/test_bugs.py                              +14  −18
tests/nn/pipe_process/test_inplace.py                           +12  −12
tests/nn/pipe_process/test_pipe.py                              +127 −186
tests/nn/pipe_process/test_transparency.py                      +4   −5
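The thrust of the change: the asynchronous schedule is no longer selected by passing style=MultiProcessPipe.AsyncSchedule to MultiProcessPipe; callers construct the new AsyncPipe class instead, and the style/pipeline_style knobs disappear from call sites and test parametrizations. A minimal before/after sketch, assuming a multi-rank process group is already set up (for example by the torch_spawn test helper) so that get_worker_map() works:

import torch.nn as nn

from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))

# Before this commit: the async schedule was a constructor flag on MultiProcessPipe.
# pipe = MultiProcessPipe(model, balance=[1, 1], style=MultiProcessPipe.AsyncSchedule,
#                         worker_map=get_worker_map(), chunks=4)

# After this commit: AsyncPipe pins the schedule internally; the flag is gone.
pipe = AsyncPipe(model, balance=[1, 1], worker_map=get_worker_map(), chunks=4)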
benchmarks/experimental_ampnet.py  (view file @ eaee5976)

@@ -21,7 +21,7 @@ from torchtext.data.utils import get_tokenizer
 from experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map

@@ -420,7 +420,6 @@ def run_mp_worker(args, available_workers):
     p = pipe.AMPnetPipe(
         module=model,
         balance=balance,
-        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
benchmarks/pipe.py  (view file @ eaee5976)

@@ -499,7 +499,6 @@ def run_mp_worker(args, available_workers):
     pipe_model = MultiProcessPipe(
         model,
         balance,
-        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
experimental/nn/ampnet_pipe/ampnet.py  (view file @ eaee5976)

@@ -37,7 +37,6 @@ def create_task_without_skip_trackers(
     checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
 ) -> Task:
     # Determine whether checkpointing or not.
-    # style is guaranteed to be PipelineStyle.AsyncSchedule
     if i < checkpoint_stop:

         def function(
experimental/nn/ampnet_pipe/pipe.py  (view file @ eaee5976)

@@ -11,15 +11,14 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader

-from fairscale.nn.pipe import MultiProcessPipe
-from fairscale.nn.pipe.types import PipelineStyle
+from fairscale.nn.pipe import AsyncPipe

 from .ampnet import AsyncAMPnetEventLoop

 __all__ = ["AMPnetPipe"]


-class AMPnetPipe(MultiProcessPipe):
+class AMPnetPipe(AsyncPipe):
     """
     AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
     which avoids the bubble issue, by using stale weights and gradients.

@@ -44,7 +43,6 @@ class AMPnetPipe(MultiProcessPipe):

         # AMPnet implementation doesn't handle skip_trackers!

-        assert self.pipeline.style is PipelineStyle.AsyncSchedule  # type: ignore
         assert self.group

         rank = self.group.rank()
experimental/tests/nn/ampnet_pipe_process/test_ampnet_pipe.py  (view file @ eaee5976)

@@ -23,7 +23,6 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset

 from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -84,14 +83,7 @@ class FakeDataset(Dataset):
 @torch_spawn([2])
 def async_event_loop_interleave_simple():
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
-    pipe = AMPnetPipe(
-        module=model,
-        balance=[2, 2],
-        style=MultiProcessPipe.AsyncSchedule,
-        worker_map=get_worker_map(),
-        chunks=10,
-        checkpoint="never",
-    )
+    pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()

@@ -102,14 +94,7 @@ def async_event_loop_interleave_simple():
 @torch_spawn([4])
 def async_event_loop_interleave_hard():
     model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
-    pipe = AMPnetPipe(
-        module=model,
-        balance=[1, 1, 1, 1],
-        style=MultiProcessPipe.AsyncSchedule,
-        worker_map=get_worker_map(),
-        chunks=10,
-        checkpoint="never",
-    )
+    pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()
fairscale/nn/pipe/__init__.py  (view file @ eaee5976)

@@ -18,6 +18,7 @@
 # limitations under the License.

 """A Pipe implementation in PyTorch."""
+from .async_pipe import AsyncPipe
 from .checkpoint import is_checkpointing, is_recomputing
 from .multiprocess_pipe import LazyModule, MultiProcessPipe
 from .pipe import Pipe
fairscale/nn/pipe/async_pipe.py  (new file, 0 → 100644, view file @ eaee5976)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .multiprocess_pipe import MultiProcessPipe
from .types import PipelineStyle


class AsyncPipe(MultiProcessPipe):
    def __init__(self, *args, **kwargs) -> None:  # type: ignore
        super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
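The new class is a thin wrapper: it forwards every argument to MultiProcessPipe and pins style to PipelineStyle.AsyncSchedule. One side effect worth noting (an observation about the forwarding code, not something stated in the diff) is that passing style= to AsyncPipe explicitly would now raise a TypeError for a duplicate keyword argument. A self-contained sketch of the same forwarding pattern, with a hypothetical BasePipe standing in for MultiProcessPipe:

from enum import Enum, auto


class PipelineStyle(Enum):
    MultiProcess = auto()
    AsyncSchedule = auto()


class BasePipe:
    # Stand-in for MultiProcessPipe: accepts a style plus other settings.
    def __init__(self, module, *, style: PipelineStyle, chunks: int = 1) -> None:
        self.module, self.style, self.chunks = module, style, chunks


class AsyncPipe(BasePipe):
    def __init__(self, *args, **kwargs) -> None:
        # Pin the schedule; a caller-supplied style= would collide with this keyword.
        super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)


pipe = AsyncPipe("dummy-module", chunks=4)
assert pipe.style is PipelineStyle.AsyncSchedule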
fairscale/nn/pipe/multiprocess_pipe.py  (view file @ eaee5976)

@@ -386,9 +386,6 @@ class MultiProcessPipe(Module):
     """

-    MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
-    AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
-
     #: The number of layers in each partition.
     balance: List[int] = []
     # ^^
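With the MultiProcess and AsyncSchedule class attributes removed, code that still needs to name a schedule refers to the enum directly, the same way the new async_pipe.py does. A one-line illustration of that import path (the path appears in the diffs for async_pipe.py and test_pipe.py):

from fairscale.nn.pipe.types import PipelineStyle

style = PipelineStyle.AsyncSchedule  # previously spelled MultiProcessPipe.AsyncSchedule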
fairscale/nn/pipe/rpc.py  (view file @ eaee5976)

@@ -13,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_global_rank
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group

+from .async_pipe import AsyncPipe
 from .multiprocess_pipe import MultiProcessPipe
 from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors

@@ -105,10 +106,9 @@ class PipeRPCWrapper(nn.Module):
         else:
             kwargs["group"] = self.group

-        kwargs["style"] = MultiProcessPipe.AsyncSchedule
         kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())

-        self.model = MultiProcessPipe(*args, **kwargs)
+        self.model = AsyncPipe(*args, **kwargs)
         self.worker_map = kwargs["worker_map"]
         self._foreach_worker(self._register_remote_model, args=(args, kwargs))
         self.model.cuda()
fairscale/nn/pipe/types.py  (view file @ eaee5976)

@@ -35,7 +35,6 @@ class LazyModule:

 class PipelineStyle(Enum):
-    SingleProcess = auto()
     MultiProcess = auto()
     AsyncSchedule = auto()
tests/nn/model_parallel/test_layers.py  (view file @ eaee5976)

@@ -431,7 +431,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
     model[2].weight.data = saved_weight_2
     worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
-    style = MultiProcessPipe.MultiProcess  # MultiProcessPipe.AsyncSchedule

     if pipe_world_size == 2:
         print(f"actually doing pipe stuff now")

@@ -440,7 +439,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         pipe_model = MultiProcessPipe(
             model,
             [2, 1],
-            style=style,
             group=pipeline_devices,
             worker_map=worker_map,
             input_device=torch.cuda.current_device(),

@@ -507,7 +505,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         failed = False
         with torch.autograd.profiler.profile() as prof:
             try:
-                if style == MultiProcessPipe.MultiProcess:
-                    pipe_model.back_helper(pipe_output)
+                pipe_model.back_helper(pipe_output)
             except Exception as e:
                 failed = True
tests/nn/pipe_process/skip/test_gpipe.py  (view file @ eaee5976)

@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -33,14 +33,14 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
-def x1to3(balance, checkpoint, pipeline_style):
+def x1to3(balance, checkpoint, pipe_class):
     torch.manual_seed(0)
-    if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
+    if pipe_class == AsyncPipe and len(balance) > 1:
         print(f"skipping yarg")
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+        pytest.skip("Skip tensors NYI for AsyncPipe")

     @skippable(stash=["1to3"])
     class Layer1(nn.Module):

@@ -74,13 +74,12 @@ def x1to3(balance, checkpoint, pipeline_style):
         return output

     model = nn.Sequential(Layer1(), Layer2(), Layer3())
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         balance,
         chunks=3,
         checkpoint=checkpoint,
         input_device=torch.cuda.current_device(),
-        style=pipeline_style,
         worker_map=get_worker_map(),
         pipelined_backward=False,
     ).cuda()

@@ -106,11 +105,11 @@ def x1to3(balance, checkpoint, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skip(reason="flaky test")
-def none_skip(pipeline_style):
-    if pipeline_style == MultiProcessPipe.AsyncSchedule:
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+def none_skip(pipe_class):
+    if pipe_class == AsyncPipe:
+        pytest.skip("Skip tensors NYI for AsyncPipe")

     @skippable(stash=["none"])
     class Stash(nn.Module):

@@ -126,13 +125,8 @@ def none_skip(pipeline_style):
         return input

     model = nn.Sequential(Stash(), Pop())
-    model = MultiProcessPipe(
-        model,
-        [1, 1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        input_device=torch.cuda.current_device(),
-        chunks=5,
+    model = pipe_class(
+        model, [1, 1], worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5,
     ).cuda()

     input = torch.rand(10, requires_grad=True).cuda()

@@ -161,8 +155,8 @@ def none_skip(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def lazy_skippable_error(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def lazy_skippable_error(pipe_class):
     """Using skippable layers in combination with lazy construction is currently
     not supported, check that it raises an Exception"""

@@ -181,6 +175,6 @@ def lazy_skippable_error(pipeline_style):
     ]

     with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
-        MultiProcessPipe(
-            model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
+        pipe_class(
+            model, [2, 1], worker_map=get_worker_map(),
         )
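The remaining test files below repeat the same mechanical rewrite: the pipeline_style parametrization becomes a pipe_class parametrization over [MultiProcessPipe, AsyncPipe], the style= keyword drops out of the constructor calls, and skips keyed on the async schedule compare against AsyncPipe instead. A condensed sketch of the pattern (the test name and model are illustrative, not taken from the diff):

import pytest
from torch import nn

from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def example_construct(pipe_class):
    # The body is identical for both classes; only the constructor differs.
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    pipe = pipe_class(model, balance=[1, 1], worker_map=get_worker_map(), chunks=2)
    assert pipe.chunks == 2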
tests/nn/pipe_process/skip/test_leak.py  (view file @ eaee5976)

@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe, is_checkpointing, is_recomputing
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.tracker import current_skip_tracker
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -46,10 +46,10 @@ class Pop(nn.Module):
 @torch_spawn([2])
 @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
 @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def delete_portal_tensor(train, checkpoint, pipeline_style):
+def delete_portal_tensor(train, checkpoint, pipe_class):
     # Without checkpointing:
     # +- Stash --+  +--- Pop ----+ - - - layers
     # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function

@@ -60,8 +60,8 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
     # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
     # +----------+  +------------+  +------------+  +----------+

-    if pipeline_style == MultiProcessPipe.AsyncSchedule:
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+    if pipe_class == AsyncPipe:
+        pytest.skip("Skip tensors NYI for AsyncPipe")

     def portal_tensor_life_is(tensor_life, skip_tracker=None):
         if skip_tracker is None:

@@ -114,9 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
             return self.F.apply(input)

     model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
-    model = MultiProcessPipe(
-        model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
-    )
+    model = pipe_class(model, balance=[2, 1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,)

     input = torch.rand(10, requires_grad=True)
tests/nn/pipe_process/test_bugs.py  (view file @ eaee5976)

@@ -22,15 +22,15 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def python_autograd_function(pipeline_style):
-    # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def python_autograd_function(pipe_class):
+    # FIXME deadlock with AsyncPipe?
     # A Python autograd function might fail with this error:
     #
     # RuntimeError: Returning Variables sharing storage with other Variables

@@ -57,9 +57,7 @@ def python_autograd_function(pipeline_style):
         return Identity.apply(input)

     model = nn.Sequential(M(), M())
-    model = MultiProcessPipe(
-        model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
-    ).cuda()
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always").cuda()
     model.eval()

     x = torch.rand(42)

@@ -73,8 +71,8 @@ def python_autograd_function(pipeline_style):
 @torch_spawn([3])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def exception_no_hang(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def exception_no_hang(pipe_class):
     # In v0.0.2, once a failed partition receives a normal message
     # (non-closing) for the next micro-batch, a hang occured. The reason was
     # that a failed partition didn't call in_queue.task_done() on a normal

@@ -92,7 +90,7 @@ def exception_no_hang(pipeline_style):
             raise ExpectedException()

     model = nn.Sequential(Pass(), Pass(), Raise())
-    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map(), chunks=3)
     model.eval()

     if model.group.rank() == 2:

@@ -106,8 +104,8 @@ def exception_no_hang(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def tuple_wait(cuda_sleep, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def tuple_wait(cuda_sleep, pipe_class):
     # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
     # Under this behavior, if checkpointing was disabled, there's a possibility
     # that gradient accumulations on other tensors are not synchronized

@@ -135,10 +133,9 @@ def tuple_wait(cuda_sleep, pipeline_style):
         return a + b + c

     model = nn.Sequential(Layer1(), Layer2())
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         [1, 1],
-        style=pipeline_style,
         worker_map=get_worker_map(),
         input_device=torch.cuda.current_device(),
         chunks=32,

@@ -160,8 +157,8 @@ def tuple_wait(cuda_sleep, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def parallel_randoms(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def parallel_randoms(pipe_class):
     class Dropouts(nn.Module):
         def forward(self, x):
             for _ in range(100):

@@ -172,10 +169,9 @@ def parallel_randoms(pipeline_style):
     x = torch.rand(10, 10, requires_grad=True).cuda()
     x.retain_grad()
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         [1, 1],
-        style=pipeline_style,
         input_device=torch.cuda.current_device(),
         worker_map=get_worker_map(),
         chunks=10,
tests/nn/pipe_process/test_inplace.py  (view file @ eaee5976)

@@ -21,21 +21,21 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def inplace_on_requires_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def inplace_on_requires_grad(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
-    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)

-    if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
-        # With AsyncSchedule, model will wait forever for gradients if not eval
+    if pipe_class == AsyncPipe and model.group.rank() == 0:
+        # With AsyncPipe, model will wait forever for gradients if not eval
         model.eval()

     y = model(x)

@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def inplace_on_not_requires_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def inplace_on_not_requires_grad(pipe_class):
     # In-place operation on a tensor not requiring grad doesn't cause a
     # RuntimeError. Currently, we cannot detect this case.
     model = nn.Sequential(nn.ReLU(inplace=True))
-    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)
     y = model(x)

@@ -70,8 +70,8 @@ def inplace_on_not_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def inplace_incorrect_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def inplace_incorrect_grad(pipe_class):
     class M(nn.Module):
         def forward(self, foo_bar):
             # 'foo' requires grad but 'bar' does not. In-place operation on

@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
         return foo * bar

     model = nn.Sequential(M())
-    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
     foo = torch.tensor([1.0], requires_grad=True)
     bar = torch.tensor([1.0])
tests/nn/pipe_process/test_pipe.py
View file @
eaee5976
...
@@ -31,15 +31,16 @@ from fairscale.nn.model_parallel.initialize import (
...
@@ -31,15 +31,16 @@ from fairscale.nn.model_parallel.initialize import (
get_pipeline_parallel_group
,
get_pipeline_parallel_group
,
initialize_model_parallel
,
initialize_model_parallel
,
)
)
from
fairscale.nn.pipe
import
LazyModule
,
MultiProcessPipe
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.nn.pipe.types
import
PipelineStyle
from
fairscale.utils.testing
import
get_worker_map
,
set_random_seed
,
torch_spawn
,
torch_version
from
fairscale.utils.testing
import
get_worker_map
,
set_random_seed
,
torch_spawn
,
torch_version
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
parameters
(
pipe
line_style
):
def
parameters
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
pipe
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
1
)
pipe
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
)
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
assert
list
(
pipe
.
parameters
())
!=
[]
assert
list
(
pipe
.
parameters
())
!=
[]
else
:
else
:
...
@@ -107,8 +108,8 @@ def mpi():
...
@@ -107,8 +108,8 @@ def mpi():
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
public_attrs
(
pipe
line_style
):
def
public_attrs
(
pipe
_class
):
class
MyString
:
class
MyString
:
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
self
.
value
=
value
self
.
value
=
value
...
@@ -118,14 +119,7 @@ def public_attrs(pipeline_style):
...
@@ -118,14 +119,7 @@ def public_attrs(pipeline_style):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
pipe
=
MultiProcessPipe
(
pipe
=
pipe_class
(
model
,
balance
=
(
1
,),
worker_map
=
get_worker_map
(),
chunks
=
42.000
,
checkpoint
=
MyString
(
"always"
),)
model
,
balance
=
(
1
,),
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
42.000
,
checkpoint
=
MyString
(
"always"
),
)
assert
pipe
.
balance
==
[
1
]
assert
pipe
.
balance
==
[
1
]
assert
pipe
.
chunks
==
42
assert
pipe
.
chunks
==
42
...
@@ -136,13 +130,13 @@ def public_attrs(pipeline_style):
...
@@ -136,13 +130,13 @@ def public_attrs(pipeline_style):
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"balance"
,
[[
2
],
[
1
,
1
]])
@
pytest
.
mark
.
parametrize
(
"balance"
,
[[
2
],
[
1
,
1
]])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
sequential_like
(
balance
,
pipe
line_style
):
def
sequential_like
(
balance
,
pipe
_class
):
a
=
nn
.
Linear
(
1
,
1
)
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
MultiProcessPipe
(
model
,
balance
,
style
=
pipeline_styl
e
,
worker_map
=
get_worker_map
())
model
=
pipe_class
(
model
,
balanc
e
,
worker_map
=
get_worker_map
())
if
balance
==
[
2
]:
if
balance
==
[
2
]:
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
...
@@ -175,62 +169,62 @@ def sequential_like(balance, pipeline_style):
...
@@ -175,62 +169,62 @@ def sequential_like(balance, pipeline_style):
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
balance_wrong_length
(
pipe
line_style
):
def
balance_wrong_length
(
pipe
_class
):
a
=
nn
.
Linear
(
1
,
1
)
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
nn
.
Sequential
(
a
,
b
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
())
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
MultiProcessPipe
(
model
,
balance
=
[
3
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe_class
(
model
,
balance
=
[
3
],
worker_map
=
get_worker_map
())
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
balance_less_than_1
(
pipe
line_style
):
def
balance_less_than_1
(
pipe
_class
):
a
=
nn
.
Linear
(
1
,
1
)
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
nn
.
Sequential
(
a
,
b
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
MultiProcessPipe
(
model
,
balance
=
[
0
,
2
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe_class
(
model
,
balance
=
[
0
,
2
],
worker_map
=
get_worker_map
())
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
MultiProcessPipe
(
model
,
balance
=
[
-
1
,
3
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe_class
(
model
,
balance
=
[
-
1
,
3
],
worker_map
=
get_worker_map
())
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
chunks_less_than_1
(
pipe
line_style
):
def
chunks_less_than_1
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
0
)
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
0
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=-
1
)
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=-
1
)
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
too_few_devices
(
pipe
line_style
):
def
too_few_devices
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
))
with
pytest
.
raises
(
IndexError
):
with
pytest
.
raises
(
IndexError
):
# len(balance) > len(group.size())
# len(balance) > len(group.size())
model
=
MultiProcessPipe
(
model
,
balance
=
[
1
,
1
,
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
model
=
pipe_class
(
model
,
balance
=
[
1
,
1
,
1
,
1
],
worker_map
=
get_worker_map
())
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
batch_size_indivisible
(
pipe
line_style
):
def
batch_size_indivisible
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
4
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
4
)
with
pytest
.
warns
(
None
)
as
record
:
with
pytest
.
warns
(
None
)
as
record
:
model
(
torch
.
rand
(
7
,
1
))
model
(
torch
.
rand
(
7
,
1
))
...
@@ -240,10 +234,10 @@ def batch_size_indivisible(pipeline_style):
...
@@ -240,10 +234,10 @@ def batch_size_indivisible(pipeline_style):
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
batch_size_small
(
pipe
line_style
):
def
batch_size_small
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
4
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
4
)
with
pytest
.
warns
(
None
)
as
record
:
with
pytest
.
warns
(
None
)
as
record
:
model
(
torch
.
rand
(
2
,
1
))
model
(
torch
.
rand
(
2
,
1
))
...
@@ -253,8 +247,8 @@ def batch_size_small(pipeline_style):
...
@@ -253,8 +247,8 @@ def batch_size_small(pipeline_style):
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
checkpoint_mode
(
pipe
line_style
):
def
checkpoint_mode
(
pipe
_class
):
def
count_grad_fn
(
grad_fn
,
name
,
visited
=
set
()):
def
count_grad_fn
(
grad_fn
,
name
,
visited
=
set
()):
if
grad_fn
in
visited
:
if
grad_fn
in
visited
:
return
0
return
0
...
@@ -273,32 +267,14 @@ def checkpoint_mode(pipeline_style):
...
@@ -273,32 +267,14 @@ def checkpoint_mode(pipeline_style):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
input
=
torch
.
rand
(
2
,
1
)
input
=
torch
.
rand
(
2
,
1
)
always
=
MultiProcessPipe
(
always
=
pipe_class
(
model
,
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"always"
,
pipelined_backward
=
False
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"always"
,
pipelined_backward
=
False
,
)
)
except_last
=
MultiProcessPipe
(
except_last
=
pipe_class
(
model
,
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"except_last"
,
pipelined_backward
=
False
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"except_last"
,
pipelined_backward
=
False
,
)
)
never
=
MultiProcessPipe
(
never
=
pipe_class
(
model
,
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"never"
,
pipelined_backward
=
False
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"never"
,
pipelined_backward
=
False
,
)
)
always_output
=
always
(
input
)
always_output
=
always
(
input
)
...
@@ -311,45 +287,34 @@ def checkpoint_mode(pipeline_style):
...
@@ -311,45 +287,34 @@ def checkpoint_mode(pipeline_style):
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
checkpoint_mode_invalid
(
pipe
line_style
):
def
checkpoint_mode_invalid
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
with
pytest
.
raises
(
ValueError
,
match
=
"checkpoint is not one of 'always', 'except_last', or 'never'"
):
with
pytest
.
raises
(
ValueError
,
match
=
"checkpoint is not one of 'always', 'except_last', or 'never'"
):
MultiProcessPipe
(
pipe_class
(
model
,
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"INVALID_CHECKPOINT"
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"INVALID_CHECKPOINT"
,
)
)
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
checkpoint_mode_when_chunks_1
(
pipe
line_style
):
def
checkpoint_mode_when_chunks_1
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
# All checkpoint modes are fine.
# All checkpoint modes are fine.
MultiProcessPipe
(
pipe_class
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"except_last"
,
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"except_last"
,
)
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"always"
)
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"never"
)
)
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"always"
)
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"never"
)
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
checkpoint_eval
(
pipe
line_style
):
def
checkpoint_eval
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
MultiProcessPipe
(
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
pipelined_backward
=
False
,)
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
2
,
pipelined_backward
=
False
,
)
input
=
torch
.
rand
(
2
,
1
)
input
=
torch
.
rand
(
2
,
1
)
def
find_grad_fn
(
grad_fn
,
name
):
def
find_grad_fn
(
grad_fn
,
name
):
...
@@ -375,8 +340,8 @@ def checkpoint_eval(pipeline_style):
...
@@ -375,8 +340,8 @@ def checkpoint_eval(pipeline_style):
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
xfail
(
torch_version
()
<
(
1
,
6
,
0
),
reason
=
"Doesn't work on torch < 1.6.0"
,
strict
=
True
)
@
pytest
.
mark
.
xfail
(
torch_version
()
<
(
1
,
6
,
0
),
reason
=
"Doesn't work on torch < 1.6.0"
,
strict
=
True
)
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
checkpoint_non_float_input
(
pipe
line_style
):
def
checkpoint_non_float_input
(
pipe
_class
):
class
ForkNonFloat
(
nn
.
Module
):
class
ForkNonFloat
(
nn
.
Module
):
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
return
(
input
*
2
,
torch
.
tensor
([
False
]))
return
(
input
*
2
,
torch
.
tensor
([
False
]))
...
@@ -386,14 +351,8 @@ def checkpoint_non_float_input(pipeline_style):
...
@@ -386,14 +351,8 @@ def checkpoint_non_float_input(pipeline_style):
return
input
[
0
]
*
2
return
input
[
0
]
*
2
model
=
nn
.
Sequential
(
ForkNonFloat
(),
JoinNonFloat
())
model
=
nn
.
Sequential
(
ForkNonFloat
(),
JoinNonFloat
())
model
=
MultiProcessPipe
(
model
=
pipe_class
(
model
,
model
,
balance
=
[
1
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"always"
,
pipelined_backward
=
False
,
balance
=
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"always"
,
pipelined_backward
=
False
,
)
)
input
=
torch
.
rand
(
1
,
requires_grad
=
True
)
input
=
torch
.
rand
(
1
,
requires_grad
=
True
)
...
@@ -401,17 +360,17 @@ def checkpoint_non_float_input(pipeline_style):
...
@@ -401,17 +360,17 @@ def checkpoint_non_float_input(pipeline_style):
if
model
.
group
.
rank
()
==
1
:
if
model
.
group
.
rank
()
==
1
:
# with torch.autograd.detect_anomaly():
# with torch.autograd.detect_anomaly():
output
.
backward
()
output
.
backward
()
elif
pipe
line_style
==
MultiProcessPipe
.
MultiProcess
:
elif
pipe
_class
==
MultiProcessPipe
:
model
.
back_helper
(
output
)
model
.
back_helper
(
output
)
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
no_grad
(
pipe
line_style
):
def
no_grad
(
pipe
_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
2
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
)
input
=
torch
.
rand
(
2
,
1
)
input
=
torch
.
rand
(
2
,
1
)
latent
=
None
latent
=
None
...
@@ -433,8 +392,8 @@ def no_grad(pipeline_style):
...
@@ -433,8 +392,8 @@ def no_grad(pipeline_style):
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
exception
(
pipe
line_style
):
def
exception
(
pipe
_class
):
class
ExpectedException
(
Exception
):
class
ExpectedException
(
Exception
):
pass
pass
...
@@ -443,7 +402,7 @@ def exception(pipeline_style):
...
@@ -443,7 +402,7 @@ def exception(pipeline_style):
raise
ExpectedException
()
raise
ExpectedException
()
model
=
nn
.
Sequential
(
Raise
())
model
=
nn
.
Sequential
(
Raise
())
model
=
MultiProcessPipe
(
model
,
balance
=
[
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
1
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
)
with
pytest
.
raises
(
ExpectedException
):
with
pytest
.
raises
(
ExpectedException
):
model
(
torch
.
rand
(
1
))
model
(
torch
.
rand
(
1
))
...
@@ -453,8 +412,8 @@ def exception(pipeline_style):
...
@@ -453,8 +412,8 @@ def exception(pipeline_style):
@
torch_spawn
([
4
])
@
torch_spawn
([
4
])
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Not enough GPUs"
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Not enough GPUs"
)
@
pytest
.
mark
.
xfail
(
strict
=
True
)
@
pytest
.
mark
.
xfail
(
strict
=
True
)
@
pytest
.
mark
.
parametrize
(
"pipe
line_style
"
,
[
MultiProcessPipe
.
MultiProcess
,
MultiProcessPipe
.
AsyncSchedul
e
])
@
pytest
.
mark
.
parametrize
(
"pipe
_class
"
,
[
MultiProcessPipe
,
AsyncPip
e
])
def
exception_early_stop_asap
(
pipe
line_style
):
def
exception_early_stop_asap
(
pipe
_class
):
"""Even the first partitions have finished to process, the partition before
"""Even the first partitions have finished to process, the partition before
the failed partition hould be killed as soon as possible.
the failed partition hould be killed as soon as possible.
"""
"""
...
@@ -482,7 +441,7 @@ def exception_early_stop_asap(pipeline_style):
...
@@ -482,7 +441,7 @@ def exception_early_stop_asap(pipeline_style):
raise
ExpectedException
()
raise
ExpectedException
()
model
=
nn
.
Sequential
(
Pass
(),
Pass
(),
Counter
(),
Raise
())
model
=
nn
.
Sequential
(
Pass
(),
Pass
(),
Counter
(),
Raise
())
model
=
MultiProcessPipe
(
model
,
[
1
,
1
,
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
3
)
model
=
pipe_class
(
model
,
[
1
,
1
,
1
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
3
)
with
pytest
.
raises
(
ExpectedException
):
with
pytest
.
raises
(
ExpectedException
):
model
(
torch
.
rand
(
3
))
model
(
torch
.
rand
(
3
))
...
@@ -492,8 +451,8 @@ def exception_early_stop_asap(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def input_pair(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def input_pair(pipe_class):
     class Two(nn.Module):
         def __init__(self):
             super().__init__()
...
@@ -505,9 +464,7 @@ def input_pair(pipeline_style):
             return (self.fc_a(a), self.fc_b(b))
     model = nn.Sequential(Two())
-    model = MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
     a = torch.rand(10, 1, requires_grad=True)
     b = torch.rand(10, 1, requires_grad=True)
...
@@ -521,8 +478,8 @@ def input_pair(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def input_singleton(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def input_singleton(pipe_class):
     class One(nn.Module):
         def __init__(self):
             super().__init__()
...
@@ -533,9 +490,7 @@ def input_singleton(pipeline_style):
             return (self.fc(a),)
     model = nn.Sequential(One())
-    model = MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
     a = torch.rand(10, 1, requires_grad=True)
...
@@ -548,10 +503,10 @@ def input_singleton(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def input_varargs(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def input_varargs(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
     a = torch.rand(1)
     b = torch.rand(1)
...
@@ -562,14 +517,14 @@ def input_varargs(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def non_tensor(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def non_tensor(pipe_class):
     class NonTensor(nn.Module):
         def forward(self, _):
             return "hello"
     model = nn.Sequential(NonTensor())
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
     x = torch.rand(1)
     # TypeError: expected Tensor as element 0 in argument 0, but got str
...
@@ -582,14 +537,14 @@ def non_tensor(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def non_tensor_tuple(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def non_tensor_tuple(pipe_class):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
     model = nn.Sequential(NonTensorTuple())
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
     x = torch.rand(1)
     # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
...
@@ -604,8 +559,8 @@ def non_tensor_tuple(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def deferred_batch_norm(checkpoint, lazy, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def deferred_batch_norm(checkpoint, lazy, pipe_class):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
     pipe_fn = lambda: pipe_bn  # noqa: E731
...
@@ -613,14 +568,8 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=2,
-        checkpoint=checkpoint,
-        deferred_batch_norm=True,
+    pipe = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True,
     )
     x = torch.rand(4, 3, 10, 10)
...
@@ -634,8 +583,8 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
     pipe_fn = lambda: pipe_bn  # noqa: E731
...
@@ -643,14 +592,8 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=1,
-        checkpoint=checkpoint,
-        deferred_batch_norm=True,
+    pipe = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True,
     )
     x = torch.rand(4, 3, 10, 10)
...
@@ -665,15 +608,15 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def devices(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def devices(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     c = nn.Linear(1, 1)
     # There are extra two ranks.
     model = nn.Sequential(a, b, c)
-    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map())
     # Extra devices must be discarded.
     if model.group.rank() == 3:
...
@@ -681,13 +624,13 @@ def devices(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def partitions(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def partitions(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(a, b)
-    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map())
     assert isinstance(model.partitions, list)
     assert len(model) == 1
...
@@ -701,13 +644,13 @@ def partitions(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def deny_moving(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def deny_moving(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(a, b)
-    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map())
     model.cuda()
     model.cpu()
...
@@ -725,11 +668,11 @@ def deny_moving(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def empty_module(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def empty_module(pipe_class):
     # Empty sequential module is not illegal.
     model = nn.Sequential()
-    model = MultiProcessPipe(model, [], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, [], worker_map=get_worker_map())
     assert model(torch.tensor([42])) == torch.tensor([42])
     assert model((torch.tensor([42]),)) == (torch.tensor([42]),)
...
@@ -741,13 +684,13 @@ def empty_module(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def named_children(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def named_children(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
-    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map())
     names = set(n for n, _ in model.named_modules())
     if model.group.rank() == 0:
...
@@ -762,24 +705,24 @@ def named_children(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def recommend_auto_balance(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def recommend_auto_balance(pipe_class):
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # balance is required
-        MultiProcessPipe(nn.Sequential())
+        pipe_class(nn.Sequential())
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # module and sum of balance have differen length (module: 0, sum of balance: 1)
-        MultiProcessPipe(nn.Sequential(), [1])
+        pipe_class(nn.Sequential(), [1])
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # module and sum of balance have different length (module: 2, sum of balance: 1)
-        MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
+        pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def lazy_construction(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def lazy_construction(pipe_class):
     init_count = 0
     class Custom(nn.Module):
...
@@ -798,7 +741,7 @@ def lazy_construction(pipeline_style):
         LazyModule(lambda: Custom()),
     ]
-    pipe = MultiProcessPipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
+    pipe = pipe_class(model, balance=[2, 2], worker_map=get_worker_map())
     assert isinstance(pipe[0], Custom)
     assert isinstance(pipe[1], Custom)
...
@@ -808,18 +751,18 @@ def lazy_construction(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def missing_worker_map(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def missing_worker_map(pipe_class):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
     with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
-        MultiProcessPipe(model, [1, 1], style=pipeline_style)
+        pipe_class(model, [1, 1])
 @torch_spawn([2])
 @pytest.mark.skip(reason="currently broken")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
     class Surrogate(nn.Module):
         def __init__(self, module):
             super().__init__()
...
@@ -830,23 +773,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
     # FIXME(tom) can't have duplicate params with separate processes
     with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
-        MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+        pipe_class(model, [1, 1], worker_map=get_worker_map())
 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def pipelined_backward(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def pipelined_backward(pipe_class):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
     destroy_model_parallel()
     initialize_model_parallel(1, 4)
-    pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
     assert pipe.pipelined_backward is False
     destroy_model_parallel()
     initialize_model_parallel(2, 2)
-    pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
     assert pipe.pipelined_backward is True
...
@@ -855,9 +798,7 @@ def pipelined_backward(pipeline_style):
 def async_event_loop():
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
-    pipe = MultiProcessPipe(
-        model, [1, 1, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10
-    )
+    pipe = AsyncPipe(model, [1, 1, 1, 1], worker_map=get_worker_map(), chunks=10)
     inputs = torch.rand(100, 10)
...
@@ -873,7 +814,7 @@ def reuse_lazy():
     reused = LazyModule(lambda: nn.Linear(10, 10))
     model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
     # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
-    pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
+    pipe = AsyncPipe(model, [3, 1, 1], worker_map=get_worker_map())
     pipe.eval()
     output = pipe(torch.rand(10))
...
@@ -891,7 +832,7 @@ def reuse_lazy():
     # ensure identical weights but no sharing between model and pipe
     reused = nn.Linear(10, 10)
     layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    pipe = MultiProcessPipe(layers, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
+    pipe = AsyncPipe(layers, [3, 1, 1], worker_map=get_worker_map())
     pipe.eval()
     model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
...
@@ -964,7 +905,7 @@ def test_instantiate_partition():
     # instantiated model
     for rank in range(len(balance)):
         instantiated = instantiate_partition(
-            model, balance, FakeGroup(rank, len(balance)), MultiProcessPipe.AsyncSchedule
+            model, balance, FakeGroup(rank, len(balance)), PipelineStyle.AsyncSchedule
         )
         for part in instantiated:
             assert isinstance(part.module, nn.Sequential)
...
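Note on the pattern above: throughout test_pipe.py the schedule is no longer selected with a style= argument; it is implied by the pipe class itself, and the tests parametrize over the two classes. A minimal, self-contained sketch of that pytest parametrization pattern (the Fake* classes below are illustrative stand-ins, not the fairscale implementations):

import pytest


class FakeMultiProcessPipe:
    # Stand-in for MultiProcessPipe: the synchronous multiprocess schedule.
    schedule = "multiprocess"


class FakeAsyncPipe:
    # Stand-in for AsyncPipe: the asynchronous schedule.
    schedule = "async"


# Parametrizing over classes replaces parametrizing over a style flag.
@pytest.mark.parametrize("pipe_class", [FakeMultiProcessPipe, FakeAsyncPipe])
def test_pipe_class_implies_schedule(pipe_class):
    # Each class carries its own schedule; no style= argument is passed.
    assert pipe_class().schedule in {"multiprocess", "async"}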
tests/nn/pipe_process/test_transparency.py
View file @ eaee5976
...
@@ -21,14 +21,14 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def simple_linears(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def simple_linears(pipe_class):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])
...
@@ -54,8 +54,7 @@ def simple_linears(pipeline_style):
     zero_grad(model.parameters())
-    # With MultiProcessPipe
-    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = pipe_class(model, [2, 2], worker_map=get_worker_map(), chunks=4)
     outputs = model(inputs)
     if model.group.rank() == 1:
...
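With this commit, AsyncPipe is importable from fairscale.nn.pipe alongside MultiProcessPipe, so callers choose a class instead of passing style=MultiProcessPipe.AsyncSchedule. A hedged sketch of how calling code might pick between the two (the helper function and its flag are illustrative, not part of fairscale; assumes fairscale at this commit is installed):

from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe


def select_pipe_class(use_async_schedule: bool):
    # AsyncPipe replaces MultiProcessPipe(..., style=MultiProcessPipe.AsyncSchedule);
    # MultiProcessPipe keeps the default multiprocess schedule.
    return AsyncPipe if use_async_schedule else MultiProcessPipe


print(select_pipe_class(True).__name__)   # AsyncPipe
print(select_pipe_class(False).__name__)  # MultiProcessPipe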