OpenDAS / fairscale / Commits / eaee5976

Unverified commit eaee5976, authored Jan 29, 2021 by msbaines, committed by GitHub on Jan 29, 2021.

[refactor] make AsyncPipe its own class (#341)

parent 51625eda
Showing 17 changed files with 200 additions and 286 deletions (+200 −286):
benchmarks/experimental_ampnet.py (+1 −2)
benchmarks/pipe.py (+0 −1)
experimental/nn/ampnet_pipe/ampnet.py (+0 −1)
experimental/nn/ampnet_pipe/pipe.py (+2 −4)
experimental/tests/nn/ampnet_pipe_process/test_ampnet_pipe.py (+2 −17)
fairscale/nn/pipe/__init__.py (+1 −0)
fairscale/nn/pipe/async_pipe.py (+12 −0)
fairscale/nn/pipe/multiprocess_pipe.py (+0 −3)
fairscale/nn/pipe/rpc.py (+2 −2)
fairscale/nn/pipe/types.py (+0 −1)
tests/nn/model_parallel/test_layers.py (+1 −4)
tests/nn/pipe_process/skip/test_gpipe.py (+16 −22)
tests/nn/pipe_process/skip/test_leak.py (+6 −8)
tests/nn/pipe_process/test_bugs.py (+14 −18)
tests/nn/pipe_process/test_inplace.py (+12 −12)
tests/nn/pipe_process/test_pipe.py (+127 −186)
tests/nn/pipe_process/test_transparency.py (+4 −5)
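Taken together, these diffs make a single API change: the asynchronous schedule is no longer selected by passing a style flag to MultiProcessPipe; callers construct the new AsyncPipe class instead. A minimal before/after sketch, reconstructed from the changed call sites in this commit (the module, balance, and chunks values are placeholders for illustration, not taken from the diff):

    import torch.nn as nn
    from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
    from fairscale.utils.testing import get_worker_map

    module = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
    balance = [1, 1]

    # before this commit: schedule chosen with a style flag
    model = MultiProcessPipe(module, balance, style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map(), chunks=4)

    # after this commit: schedule chosen by picking the class
    model = AsyncPipe(module, balance, worker_map=get_worker_map(), chunks=4)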
benchmarks/experimental_ampnet.py

@@ -21,7 +21,7 @@ from torchtext.data.utils import get_tokenizer
 from experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map

@@ -420,7 +420,6 @@ def run_mp_worker(args, available_workers):
     p = pipe.AMPnetPipe(
         module=model,
         balance=balance,
-        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
benchmarks/pipe.py

@@ -499,7 +499,6 @@ def run_mp_worker(args, available_workers):
     pipe_model = MultiProcessPipe(
         model,
         balance,
-        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
experimental/nn/ampnet_pipe/ampnet.py

@@ -37,7 +37,6 @@ def create_task_without_skip_trackers(
     checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
 ) -> Task:
     # Determine whether checkpointing or not.
-    # style is guaranteed to be PipelineStyle.AsyncSchedule
     if i < checkpoint_stop:

         def function(
experimental/nn/ampnet_pipe/pipe.py

@@ -11,15 +11,14 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader

-from fairscale.nn.pipe import MultiProcessPipe
-from fairscale.nn.pipe.types import PipelineStyle
+from fairscale.nn.pipe import AsyncPipe

 from .ampnet import AsyncAMPnetEventLoop

 __all__ = ["AMPnetPipe"]


-class AMPnetPipe(MultiProcessPipe):
+class AMPnetPipe(AsyncPipe):
     """
     AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
     which avoids the bubble issue, by using stale weights and gradients.

@@ -44,7 +43,6 @@ class AMPnetPipe(MultiProcessPipe):
         # AMPnet implementation doesn't handle skip_trackers!
-        assert self.pipeline.style is PipelineStyle.AsyncSchedule  # type: ignore
         assert self.group
         rank = self.group.rank()
experimental/tests/nn/ampnet_pipe_process/test_ampnet_pipe.py

@@ -23,7 +23,6 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset

 from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -84,14 +83,7 @@ class FakeDataset(Dataset):
 @torch_spawn([2])
 def async_event_loop_interleave_simple():
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
-    pipe = AMPnetPipe(
-        module=model,
-        balance=[2, 2],
-        style=MultiProcessPipe.AsyncSchedule,
-        worker_map=get_worker_map(),
-        chunks=10,
-        checkpoint="never",
-    )
+    pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()

@@ -102,14 +94,7 @@ def async_event_loop_interleave_simple():
 @torch_spawn([4])
 def async_event_loop_interleave_hard():
     model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
-    pipe = AMPnetPipe(
-        module=model,
-        balance=[1, 1, 1, 1],
-        style=MultiProcessPipe.AsyncSchedule,
-        worker_map=get_worker_map(),
-        chunks=10,
-        checkpoint="never",
-    )
+    pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()
fairscale/nn/pipe/__init__.py

@@ -18,6 +18,7 @@
 # limitations under the License.

 """A Pipe implementation in PyTorch."""
+from .async_pipe import AsyncPipe
 from .checkpoint import is_checkpointing, is_recomputing
 from .multiprocess_pipe import LazyModule, MultiProcessPipe
 from .pipe import Pipe
fairscale/nn/pipe/async_pipe.py (new file, mode 100644)

+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .multiprocess_pipe import MultiProcessPipe
+from .types import PipelineStyle
+
+
+class AsyncPipe(MultiProcessPipe):
+    def __init__(self, *args, **kwargs) -> None:  # type: ignore
+        super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
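The new class is deliberately thin: it pins style to PipelineStyle.AsyncSchedule and forwards every other argument to MultiProcessPipe. That is what lets AMPnetPipe (shown earlier) switch its base class without further changes. A minimal sketch of extending it the same way, assuming only what the new file above provides (the subclass name here is hypothetical):

    from fairscale.nn.pipe import AsyncPipe

    class MyAsyncPipe(AsyncPipe):
        # Inherits MultiProcessPipe's constructor arguments unchanged;
        # the AsyncSchedule style is already fixed by AsyncPipe.__init__.
        pass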
fairscale/nn/pipe/multiprocess_pipe.py

@@ -386,9 +386,6 @@ class MultiProcessPipe(Module):
     """

-    MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
-    AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
-
     #: The number of layers in each partition.
     balance: List[int] = []
     # ^^
fairscale/nn/pipe/rpc.py

@@ -13,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_global_rank
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group

+from .async_pipe import AsyncPipe
 from .multiprocess_pipe import MultiProcessPipe
 from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors

@@ -105,10 +106,9 @@ class PipeRPCWrapper(nn.Module):
         else:
             kwargs["group"] = self.group

-        kwargs["style"] = MultiProcessPipe.AsyncSchedule
         kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())

-        self.model = MultiProcessPipe(*args, **kwargs)
+        self.model = AsyncPipe(*args, **kwargs)
         self.worker_map = kwargs["worker_map"]
         self._foreach_worker(self._register_remote_model, args=(args, kwargs))
         self.model.cuda()
fairscale/nn/pipe/types.py

@@ -35,7 +35,6 @@ class LazyModule:
 class PipelineStyle(Enum):
-    SingleProcess = auto()
     MultiProcess = auto()
     AsyncSchedule = auto()
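After this removal the enum presumably reduces to the two members still visible in the context lines above (reconstructed for reference; the import is implied by the existing file):

    from enum import Enum, auto

    class PipelineStyle(Enum):
        MultiProcess = auto()
        AsyncSchedule = auto()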
tests/nn/model_parallel/test_layers.py

@@ -431,7 +431,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
     model[2].weight.data = saved_weight_2
     worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

-    style = MultiProcessPipe.MultiProcess  # MultiProcessPipe.AsyncSchedule
     if pipe_world_size == 2:
         print(f"actually doing pipe stuff now")

@@ -440,7 +439,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         pipe_model = MultiProcessPipe(
             model,
             [2, 1],
-            style=style,
             group=pipeline_devices,
             worker_map=worker_map,
             input_device=torch.cuda.current_device(),

@@ -507,8 +505,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         failed = False
         with torch.autograd.profiler.profile() as prof:
             try:
-                if style == MultiProcessPipe.MultiProcess:
-                    pipe_model.back_helper(pipe_output)
+                pipe_model.back_helper(pipe_output)
             except Exception as e:
                 failed = True
                 print(f"got {e} while doing backward, deadlock?")
tests/nn/pipe_process/skip/test_gpipe.py

@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -33,14 +33,14 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
-def x1to3(balance, checkpoint, pipeline_style):
+def x1to3(balance, checkpoint, pipe_class):
     torch.manual_seed(0)

-    if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
+    if pipe_class == AsyncPipe and len(balance) > 1:
         print(f"skipping yarg")
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+        pytest.skip("Skip tensors NYI for AsyncPipe")

     @skippable(stash=["1to3"])
     class Layer1(nn.Module):

@@ -74,13 +74,12 @@ def x1to3(balance, checkpoint, pipeline_style):
         return output

     model = nn.Sequential(Layer1(), Layer2(), Layer3())
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         balance,
         chunks=3,
         checkpoint=checkpoint,
         input_device=torch.cuda.current_device(),
-        style=pipeline_style,
         worker_map=get_worker_map(),
         pipelined_backward=False,
     ).cuda()

@@ -106,11 +105,11 @@ def x1to3(balance, checkpoint, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skip(reason="flaky test")
-def none_skip(pipeline_style):
-    if pipeline_style == MultiProcessPipe.AsyncSchedule:
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+def none_skip(pipe_class):
+    if pipe_class == AsyncPipe:
+        pytest.skip("Skip tensors NYI for AsyncPipe")

     @skippable(stash=["none"])
     class Stash(nn.Module):

@@ -126,13 +125,8 @@ def none_skip(pipeline_style):
         return input

     model = nn.Sequential(Stash(), Pop())
-    model = MultiProcessPipe(
-        model,
-        [1, 1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        input_device=torch.cuda.current_device(),
-        chunks=5,
+    model = pipe_class(
+        model, [1, 1], worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5,
     ).cuda()

     input = torch.rand(10, requires_grad=True).cuda()

@@ -161,8 +155,8 @@ def none_skip(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def lazy_skippable_error(pipeline_style):
+def lazy_skippable_error(pipe_class):
     """Using skippable layers in combination with lazy construction is currently
     not supported, check that it raises an Exception"""

@@ -181,6 +175,6 @@ def lazy_skippable_error(pipeline_style):
     ]

     with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
-        MultiProcessPipe(
-            model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
-        )
+        pipe_class(
+            model, [2, 1], worker_map=get_worker_map(),
+        )
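The same mechanical rewrite repeats across the remaining test files: instead of parametrizing over a pipeline_style value and always constructing MultiProcessPipe, the tests parametrize over the pipe class itself and call whichever class pytest hands them. A condensed sketch of the pattern (the test name is hypothetical, and the real decorators also carry the torch_spawn and skipif marks shown above):

    import pytest
    from torch import nn
    from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
    from fairscale.utils.testing import get_worker_map

    @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
    def some_pipe_test(pipe_class):
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        # no style argument anymore; the class choice selects the schedule
        pipe = pipe_class(model, [1, 1], worker_map=get_worker_map(), chunks=2)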
tests/nn/pipe_process/skip/test_leak.py

@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe, is_checkpointing, is_recomputing
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.tracker import current_skip_tracker
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -46,10 +46,10 @@ class Pop(nn.Module):
 @torch_spawn([2])
 @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
 @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def delete_portal_tensor(train, checkpoint, pipeline_style):
+def delete_portal_tensor(train, checkpoint, pipe_class):
     # Without checkpointing:
     # +- Stash --+  +--- Pop ----+ - - - layers
     # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function

@@ -60,8 +60,8 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
     # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
     # +----------+  +------------+  +------------+  +----------+

-    if pipeline_style == MultiProcessPipe.AsyncSchedule:
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+    if pipe_class == AsyncPipe:
+        pytest.skip("Skip tensors NYI for AsyncPipe")

     def portal_tensor_life_is(tensor_life, skip_tracker=None):
         if skip_tracker is None:

@@ -114,9 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
             return self.F.apply(input)

     model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
-    model = MultiProcessPipe(
-        model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
-    )
+    model = pipe_class(model, balance=[2, 1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,)

     input = torch.rand(10, requires_grad=True)
tests/nn/pipe_process/test_bugs.py

@@ -22,15 +22,15 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def python_autograd_function(pipeline_style):
-    # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
+def python_autograd_function(pipe_class):
+    # FIXME deadlock with AsyncPipe?
     # A Python autograd function might fail with this error:
     #
     # RuntimeError: Returning Variables sharing storage with other Variables

@@ -57,9 +57,7 @@ def python_autograd_function(pipeline_style):
         return Identity.apply(input)

     model = nn.Sequential(M(), M())
-    model = MultiProcessPipe(
-        model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
-    ).cuda()
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always").cuda()
     model.eval()

     x = torch.rand(42)

@@ -73,8 +71,8 @@ def python_autograd_function(pipeline_style):
 @torch_spawn([3])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def exception_no_hang(pipeline_style):
+def exception_no_hang(pipe_class):
     # In v0.0.2, once a failed partition receives a normal message
     # (non-closing) for the next micro-batch, a hang occured. The reason was
     # that a failed partition didn't call in_queue.task_done() on a normal

@@ -92,7 +90,7 @@ def exception_no_hang(pipeline_style):
         raise ExpectedException()

     model = nn.Sequential(Pass(), Pass(), Raise())
-    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map(), chunks=3)
     model.eval()

     if model.group.rank() == 2:

@@ -106,8 +104,8 @@ def exception_no_hang(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def tuple_wait(cuda_sleep, pipeline_style):
+def tuple_wait(cuda_sleep, pipe_class):
     # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
     # Under this behavior, if checkpointing was disabled, there's a possibility
     # that gradient accumulations on other tensors are not synchronized

@@ -135,10 +133,9 @@ def tuple_wait(cuda_sleep, pipeline_style):
         return a + b + c

     model = nn.Sequential(Layer1(), Layer2())
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         [1, 1],
-        style=pipeline_style,
         worker_map=get_worker_map(),
         input_device=torch.cuda.current_device(),
         chunks=32,

@@ -160,8 +157,8 @@ def tuple_wait(cuda_sleep, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def parallel_randoms(pipeline_style):
+def parallel_randoms(pipe_class):
     class Dropouts(nn.Module):
         def forward(self, x):
             for _ in range(100):

@@ -172,10 +169,9 @@ def parallel_randoms(pipeline_style):
     x = torch.rand(10, 10, requires_grad=True).cuda()
     x.retain_grad()
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         [1, 1],
-        style=pipeline_style,
         input_device=torch.cuda.current_device(),
         worker_map=get_worker_map(),
         chunks=10,
tests/nn/pipe_process/test_inplace.py

@@ -21,21 +21,21 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def inplace_on_requires_grad(pipeline_style):
+def inplace_on_requires_grad(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
-    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)

-    if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
-        # With AsyncSchedule, model will wait forever for gradients if not eval
+    if pipe_class == AsyncPipe and model.group.rank() == 0:
+        # With AsyncPipe, model will wait forever for gradients if not eval
         model.eval()

     y = model(x)

@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def inplace_on_not_requires_grad(pipeline_style):
+def inplace_on_not_requires_grad(pipe_class):
     # In-place operation on a tensor not requiring grad doesn't cause a
     # RuntimeError. Currently, we cannot detect this case.
     model = nn.Sequential(nn.ReLU(inplace=True))
-    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)
     y = model(x)

@@ -70,8 +70,8 @@ def inplace_on_not_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def inplace_incorrect_grad(pipeline_style):
+def inplace_incorrect_grad(pipe_class):
     class M(nn.Module):
         def forward(self, foo_bar):
             # 'foo' requires grad but 'bar' does not. In-place operation on

@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
         return foo * bar

     model = nn.Sequential(M())
-    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
     foo = torch.tensor([1.0], requires_grad=True)
     bar = torch.tensor([1.0])
tests/nn/pipe_process/test_pipe.py

This diff is collapsed on the page and not shown here (+127 −186).
tests/nn/pipe_process/test_transparency.py

@@ -21,14 +21,14 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
-def simple_linears(pipeline_style):
+def simple_linears(pipe_class):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])

@@ -54,8 +54,7 @@ def simple_linears(pipeline_style):
     zero_grad(model.parameters())

-    # With MultiProcessPipe
-    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = pipe_class(model, [2, 2], worker_map=get_worker_map(), chunks=4)

     outputs = model(inputs)
     if model.group.rank() == 1: