OpenDAS / fairscale · Commit cae9b638 (unverified)
Authored Jan 26, 2021 by msbaines; committed by GitHub on Jan 26, 2021
Parent: eab1551a

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

Showing 20 changed files with 1347 additions and 959 deletions (+1347 -959)
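The practical effect of the split, as reflected in the updated examples and benchmarks below, is that single-process and multi-process pipelines are now constructed through two separate classes: `Pipe` (fairscale/nn/pipe/pipe.py) and `MultiProcessPipe` (the new fairscale/nn/pipe/multiprocess_pipe.py). A minimal sketch of the two entry points; the toy model and layer sizes here are illustrative only, and the multi-process case assumes torch.distributed and RPC have already been initialized as in examples/tutorial_pipe_multiprocess.py:

    import torch.nn as nn

    from fairscale.nn import Pipe                   # single-process pipeline (pipe.py)
    from fairscale.nn.pipe import MultiProcessPipe  # one pipeline stage per process (multiprocess_pipe.py)

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    # Single process: partitions are placed on local devices.
    single = Pipe(model, balance=[2, 1], chunks=4, checkpoint="except_last")

    # One stage per process: requires an initialized process group and RPC.
    # worker_map maps global ranks to the worker names passed to init_rpc.
    multi = MultiProcessPipe(
        model,
        balance=[2, 1],
        style=MultiProcessPipe.MultiProcess,
        worker_map={0: "worker0", 1: "worker1"},
        chunks=4,
    )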
Files changed (20):

  benchmarks/experimental_ampnet.py                               +2    -3
  benchmarks/pipe.py                                              +6    -8
  examples/tutorial_pipe_multiprocess.py                          +3    -3
  experimental/nn/ampnet_pipe/pipe.py                             +4    -4
  experimental/tests/nn/ampnet_pipe_process/test_ampnet_pipe.py   +3    -3
  fairscale/nn/__init__.py                                        +1    -1
  fairscale/nn/pipe/__init__.py                                   +2    -1
  fairscale/nn/pipe/async_schedule.py                             +1    -2
  fairscale/nn/pipe/multiprocess_pipe.py                          +660  -0
  fairscale/nn/pipe/multiprocess_pipeline.py                      +440  -0
  fairscale/nn/pipe/pipe.py                                       +78   -460
  fairscale/nn/pipe/pipeline.py                                   +97   -428
  fairscale/nn/pipe/rpc.py                                        +10   -8
  fairscale/nn/pipe/types.py                                      +0    -1
  stubs/torch/multiprocessing/__init__.pyi                        +1    -1
  tests/nn/model_parallel/test_layers.py                          +5    -5
  tests/nn/pipe_process/skip/test_gpipe.py                        +10   -9
  tests/nn/pipe_process/skip/test_leak.py                         +4    -4
  tests/nn/pipe_process/test_bugs.py                              +12   -10
  tests/nn/pipe_process/test_inplace.py                           +8    -8
benchmarks/experimental_ampnet.py

@@ -19,10 +19,9 @@ import torchtext
 from torchtext.data.utils import get_tokenizer

 from experimental.nn.ampnet_pipe import pipe
-from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map

@@ -421,7 +420,7 @@ def run_mp_worker(args, available_workers):
     p = pipe.AMPnetPipe(
         module=model,
         balance=balance,
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
benchmarks/pipe.py

@@ -25,7 +25,7 @@ from torch.optim import Adam
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim.oss import OSS
 from fairscale.utils.testing import dist_init, get_worker_map

@@ -157,7 +157,7 @@ def dump_cuda_tensors():
 def log_number_of_parameters(model):
     num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
-    if model.group:
+    if hasattr(model, "group"):
         total = torch.Tensor([num_params])
         if torch.cuda.is_available():
             total = total.cuda()

@@ -212,7 +212,7 @@ def train(model_config, model, benchmark_config, args):
     optimizer = optimizer(model.parameters())
-    pipe_group = model.group
+    pipe_group = model.group if hasattr(model, "group") else None
     if args.ddp_zero:
         model = DDP(

@@ -479,9 +479,7 @@ def benchmark_single_process(args):
     model = model_config["model"]
     balance = generate_balance(min(num_devices, 4), len(model))
-    pipe_model = pipe.Pipe(
-        model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
-    )
+    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
     del model
     del model_config["model"]

@@ -498,10 +496,10 @@ def run_mp_worker(args, available_workers):
     model = model_config["model"]
     balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
-    pipe_model = pipe.Pipe(
+    pipe_model = MultiProcessPipe(
         model,
         balance,
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
examples/tutorial_pipe_multiprocess.py

@@ -6,8 +6,8 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.optim as optim

-import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
+from fairscale.nn.pipe import MultiProcessPipe

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 RANK = 0  # example

@@ -27,10 +27,10 @@ def run(rank, world_size):
     device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")

-    model = fairscale.nn.Pipe(
+    model = MultiProcessPipe(
         model,
         balance=[2, 1],
-        style=fairscale.nn.Pipe.MultiProcess,
+        style=MultiProcessPipe.MultiProcess,
         worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
         input_device=device,
     ).to(device)
experimental/nn/ampnet_pipe/pipe.py

@@ -11,7 +11,7 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader

-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.nn.pipe.types import PipelineStyle

 from .ampnet import AsyncAMPnetEventLoop

@@ -19,9 +19,9 @@ from .ampnet import AsyncAMPnetEventLoop
 __all__ = ["AMPnetPipe"]

-class AMPnetPipe(Pipe):
+class AMPnetPipe(MultiProcessPipe):
     """
-    AMPnetPipe is the asynchronous version of the Pipe implementation
+    AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
     which avoids the bubble issue, by using stale weights and gradients.
     The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
     """

@@ -39,7 +39,7 @@ class AMPnetPipe(Pipe):
         weight_prediction: bool = False,
     ) -> None:
-        partitions = self.mp_partitions
+        partitions = self.partitions
         n = len(partitions)
         # AMPnet implementation doesn't handle skip_trackers!
experimental/tests/nn/ampnet_pipe_process/test_ampnet_pipe.py

@@ -23,7 +23,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset

 from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn

@@ -87,7 +87,7 @@ def async_event_loop_interleave_simple():
     pipe = AMPnetPipe(
         module=model,
         balance=[2, 2],
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         worker_map=get_worker_map(),
         chunks=10,
         checkpoint="never",

@@ -105,7 +105,7 @@ def async_event_loop_interleave_hard():
     pipe = AMPnetPipe(
         module=model,
         balance=[1, 1, 1, 1],
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         worker_map=get_worker_map(),
         chunks=10,
         checkpoint="never",
fairscale/nn/__init__.py

@@ -6,7 +6,7 @@
 from .data_parallel import ShardedDataParallel
 from .misc import FlattenParamsWrapper
 from .moe import MOELayer, Top2Gate
-from .pipe import LazyModule, Pipe, PipeRPCWrapper
+from .pipe import Pipe, PipeRPCWrapper

 __all__ = [
     "FlattenParamsWrapper",
fairscale/nn/pipe/__init__.py

@@ -19,7 +19,8 @@
 """A Pipe implementation in PyTorch."""
 from .checkpoint import is_checkpointing, is_recomputing
-from .pipe import LazyModule, Pipe
+from .multiprocess_pipe import LazyModule, MultiProcessPipe
+from .pipe import Pipe
 from .rpc import PipeRPCWrapper

 __all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
fairscale/nn/pipe/async_schedule.py

@@ -191,7 +191,7 @@ class AsyncEventLoop:
         """Actually run the forward pass for a given module, and send the result
         to the next stage in the pipeline if needed."""
         assert self.group
-        from .pipeline import create_task
+        from .multiprocess_pipeline import create_task

         task = create_task(
             PipelineStyle.AsyncSchedule,

@@ -201,7 +201,6 @@ class AsyncEventLoop:
             batch,
             partition.module,
             skip_trackers,
-            [],
         )
         result = task.compute()
         task.finalize(result)
fairscale/nn/pipe/multiprocess_pipe.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The MultiProcessPipe interface."""

from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings

import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda

from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group

from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .types import LazyModule, PipelineStyle

__all__ = ["MultiProcessPipe", "LazyModule"]

Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]

ListOfLazyModules = List[LazyModule]

if TYPE_CHECKING:
    Module = nn.Module[TensorOrTensors]
    NamedModules = OrderedDict[str, Module]
else:
    Module = nn.Module
    NamedModules = OrderedDict


def recommend_auto_balance(message: str) -> str:
    """Expands a message with recommendation to :mod:`torchpipe.balance`."""
    return f"""{message}

If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
naive automatic balancing:

  from fairscale.nn import Pipe
  from fairscale.nn.pipe.balance import balance_by_time

  partitions = torch.cuda.device_count()
  sample = torch.empty(...)
  balance = balance_by_time(partitions, model, sample)

  model = MultiProcessPipe(model, balance, ...)
"""


# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
    for layer in module:
        if isinstance(layer, nn.Module):
            pass
        elif isinstance(layer, LazyModule):
            pass
        else:
            raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")


def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
    if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
        verify_list_of_callable(module)
    else:
        if not isinstance(module, nn.Sequential):
            raise TypeError("module must be nn.Sequential to be partitioned")

        named_children = list(module.named_children())
        if len(named_children) != len(module):
            raise ValueError("module with duplicate children is not supported")


def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None:
    num_parameters = len(list(module.parameters()))
    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
    if num_parameters == num_child_parameters:
        return

    for i in range(len(partitions)):
        for j in range(i + 1, len(partitions)):
            parti = partitions[i]
            partj = partitions[j]
            for p in parti.parameters():
                for q in partj.parameters():
                    if p is q:
                        raise ValueError("module with duplicate parameters on distinct devices is not supported")


class BalanceError(ValueError):
    pass


def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
    if filter_unique:
        module_len = len(set(map(id, module)))
    else:
        module_len = len(module)

    if module_len != sum(balance):
        raise BalanceError(
            f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
        )

    if any(x <= 0 for x in balance):
        raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")


@dataclass
class PartitionInfo:
    location: Location
    modules: "OrderedDict[str, nn.Module]"
    invocations: List[Invocation] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.modules)


def instantiate_partition(
    module: Union[nn.Sequential, ListOfLazyModules],
    balance: Iterable[int],
    group: torch.distributed.ProcessGroup,
    style: PipelineStyle,
) -> List[ModuleWrapper]:
    balance = list(balance)
    check_balance(module, balance, True)

    layers: NamedModules = OrderedDict()

    def maybe_realize(layer: Any) -> nn.Module:
        if isinstance(layer, nn.Module):
            return layer
        elif callable(layer):
            return layer()
        else:
            raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")

    def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
        if isinstance(module, nn.Sequential):
            yield from module.named_children()
        else:
            yield from ((str(k), v) for k, v in enumerate(module))

    if style == PipelineStyle.AsyncSchedule:
        module_ids = list(map(id, module))
        index_of_first_use = [module_ids.index(x) for x in module_ids]
        locations: List[Location] = []
        module_iter = enumerate(iterate_module(module))

        partitions: List[List[PartitionInfo]] = []
        for bi, b in enumerate(balance):
            modules_for_rank: List[PartitionInfo] = []
            current_module: OrderedDict[str, nn.Module] = OrderedDict()

            def current_location() -> Location:
                return Location(bi, len(modules_for_rank))

            def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
                modules_for_rank.append(PartitionInfo(current_location(), mod))

            while sum(map(len, modules_for_rank)) + len(current_module) < b:
                module_index, (name, layer) = next(module_iter)

                if index_of_first_use[module_index] != module_index:
                    # Subsequent reuse of a module
                    locations.append(locations[index_of_first_use[module_index]])
                    continue

                is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1

                if is_reused and len(current_module) > 0:
                    append_module(current_module)
                    current_module = OrderedDict()

                current_module[str(name)] = layer
                locations.append(current_location())

                if is_reused:
                    append_module(current_module)
                    current_module = OrderedDict()

            if len(current_module) > 0:
                append_module(current_module)

            partitions.append(modules_for_rank)

        filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
        filtered_locations.append(None)

        for i in range(len(filtered_locations) - 1):
            loc = filtered_locations[i]
            assert loc
            if i == 0:
                inv = Invocation(i, loc, None, filtered_locations[i + 1])
            else:
                inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])

            partitions[loc.stage][loc.index].invocations.append(inv)

        invocations = enumerate(iterate_module(module))

        partition = partitions[group.rank()]

        result: List[ModuleWrapper] = []

        for partition_info in partition:
            wrapper = ModuleWrapper(
                nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
                partition_info.location,
                partition_info.invocations,
            )

            if not isinstance(module, nn.Sequential):
                for layer in wrapper.module:
                    if isinstance(layer, Skippable):
                        raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")

            result.append(wrapper)

        return result

    j = 0

    for name, layer in iterate_module(module):
        layers[name] = layer

        if len(layers) == balance[j]:
            if j == group.rank():
                for key in layers:
                    layers[key] = maybe_realize(layers[key])
                if not isinstance(module, nn.Sequential):
                    for layer in layers.values():
                        if isinstance(layer, Skippable):
                            raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
                return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]

            # Prepare for the next partition.
            layers.clear()
            j += 1

    raise ValueError("Souldn't get here, more ranks than partitions")


def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
    """Splits a module into multiple partitions.

    Returns:
        A tuple of (partitions, balance).

        Partitions are represented as a :class:`~torch.nn.ModuleList` whose
        item is a partition. All layers in a partition are placed in the
        same device.

    Raises:
        BalanceError:
            wrong balance
        IndexError:
            the number of devices is fewer than the number of partitions.
    """
    balance = list(balance)
    check_balance(module, balance)

    j = 0
    partitions = []
    layers: NamedModules = OrderedDict()

    for name, layer in module.named_children():
        layers[name] = layer

        if len(layers) == balance[j]:
            # Group buffered layers as a partition.
            partition = nn.Sequential(layers)
            partitions.append(partition)

            # Prepare for the next partition.
            layers.clear()
            j += 1

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))

    return partitions, balance


MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")


class MultiProcessPipe(Module):
    """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
    to train on Pipe_. If the module requires lots of memory, Pipe will be
    very efficient.
    ::

        model = nn.Sequential(a, b, c, d)
        model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
        output = model(input)

    .. _Pipe: https://arxiv.org/abs/1811.06965

    Pipe combines pipeline parallelism with checkpointing to reduce peak
    memory required to train while minimizing device under-utilization.

    You should determine the balance when defining a :class:`Pipe` module, as
    balancing will not be done automatically. The module will be partitioned
    into multiple devices according to the given balance. You may rely on
    heuristics to find your own optimal configuration.

    Args:
        module (torch.nn.Sequential):
            sequential module to be parallelized
        balance (ints):
            list of number of layers in each partition

    Keyword Args:
        style (PipelineStyle):
            whether to use a single process for all pipeline stages or to assign
            one stage per process
        group (ProcessGroup):
            specific to `style=MultiProcess`, the process group that all
            pipeline stages are a member of. Defaults to
            `get_pipeline_parallel_group()`
        worker_map (Dict[int, str]):
            a map from worker name (the first argument to
            `torch.distributed.rpc.init_rpc`) to global rank (i.e.
            `torch.distributed.get_rank()`) needed in order for pipeline stages
            to communicate with each other
        input_device (device):
            the device on which tensors should be located before being passed to
            the first module in a given pipeline stage
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
            when to enable checkpointing, one of ``'always'``,
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``)
        deferred_batch_norm (bool):
            whether to use deferred BatchNorm moving statistics (default:
            :data:`False`, see :class:`DeferredBatchNorm` for more
            details)
        pipelined_backward (bool, optional):
            if True, call torch.autograd.backward once per microbatch on the
            backward pass (instead of once for the whole batch). This works
            around a potential deadlock in pytorch when using tensor parallelism
            at the same time. Defaults to `True` if
            `get_model_parallel_world_size() > 1`
            (default: `None`)
        retain_graph (bool):
            The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
            (default: = `True`)

    Raises:
        TypeError:
            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
        ValueError:
            invalid arguments, or wrong balance
        IndexError:
            the number of devices is fewer than the number of partitions.
    """

    MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
    AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule

    #: The number of layers in each partition.
    balance: List[int] = []
    # ^^
    # The default value [] required for Sphinx's autoattribute.

    #: The devices mapped to each partition.
    #:
    #: ``devices[-1]`` refers to the device of the last partition, which means
    #: it is the output device. Probably, you need to use it to transfer the
    #: target to calculate the loss without a device mismatch
    #: :exc:`RuntimeError`. For example::
    #:
    #:     out_device = pipe.devices[-1]
    #:
    #:     for input, target in loader:
    #:         target = target.to(out_device, non_blocking=True)
    #:         output = pipe(input)
    #:         loss = F.cross_entropy(output, target)
    #:

    #: The number of micro-batches.
    chunks: int = 1

    #: The checkpoint mode to determine when to enable checkpointing. It is one
    #: of ``'always'``, ``'except_last'``, or ``'never'``.
    checkpoint: str = "except_last"

    def __init__(
        self,
        module: Union[nn.Sequential, ListOfLazyModules],
        balance: Optional[Iterable[int]] = None,
        *,
        style: PipelineStyle = PipelineStyle.MultiProcess,
        group: Optional[torch.distributed.ProcessGroup] = None,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
        pipelined_backward: bool = None,
        retain_graph: bool = False,
        loss_fn: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if balance is None:
            raise ValueError(recommend_auto_balance("balance is required"))
        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        if isinstance(module, nn.Sequential):
            verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint
        self.pipelined_backward = pipelined_backward
        self.retain_graph = retain_graph
        self.pipeline: Optional[MultiProcessPipeline]
        self.loss_fn = loss_fn
        self.lock = threading.Lock()

        self.group = group
        self.worker_map = worker_map
        self.input_device = input_device

        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

        if self.group is None:
            self.group = get_pipeline_parallel_group()
        assert self.group

        self.balance = list(balance)

        if self.group.size() < len(self.balance):
            raise IndexError(
                f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                f" {len(self.balance)})"
            )

        try:
            rank = self.group.rank()
            if rank >= len(self.balance):
                warnings.warn("More ranks than partitions, some ranks unused")
                self.partitions: List[ModuleWrapper] = []
            else:
                self.partitions = instantiate_partition(module, balance, self.group, style)
                if deferred_batch_norm:
                    for part in self.partitions:
                        part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
                for name, part in enumerate(self.partitions):
                    self.add_module(str(name), part.module)
            if isinstance(module, nn.Sequential):
                local_partitions, _ = split_module(module, balance)
                self._skip_layout = inspect_skip_layout(local_partitions)
            else:
                self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc)))

        rank = self.group.rank()
        if rank >= len(self.balance):
            self.pipeline = None
            self.final_stage = False
        else:
            self.final_stage = rank == len(self.balance) - 1

            assert loss_fn is None or self.final_stage

            self.pipeline = MultiProcessPipeline(
                cast(List[nn.Sequential], self.partitions),
                self._skip_layout,
                checkpoint_stop,
                style=style,
                group=self.group,
                worker_map=self.worker_map,
                input_device=self.input_device,
                final_stage=self.final_stage,
            )

            del module

        if self.pipelined_backward is None:
            if get_model_parallel_world_size() > 1:
                self.pipelined_backward = True
            else:
                self.pipelined_backward = False

    def __len__(self) -> int:
        """Counts the length of the underlying sequential module."""
        return sum(len(p) for p in self.partitions)

    def __getitem__(self, index: int) -> nn.Module:
        """Gets a layer in the underlying sequential module."""
        partitions: List[Any]
        partitions = self.partitions

        if index < 0:
            partitions = partitions[::-1]

        for partition in partitions:
            try:
                if isinstance(partition, ModuleWrapper):
                    return partition.module[index]
                else:
                    return partition[index]
            except IndexError:
                pass

            shift = len(partition)

            if index < 0:
                index += shift
            else:
                index -= shift

        raise IndexError

    def __iter__(self) -> Iterable[nn.Module]:
        """Iterates over children of the underlying sequential module."""
        for partition in self.partitions:
            yield from partition.module

    def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors:  # type: ignore
        """:class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.

        Args:
            input (torch.Tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: input is not a tensor or tensors.
        """
        microbatch.check(input)

        if not self.group:
            # Empty sequential module is not illegal.
            return input

        if not self.pipeline:
            # No pipeline is not illegal, more ranks than partitions
            return input

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(input, self.chunks)

        # Run pipeline parallelism.
        with self.lock:
            self.pipeline.run(self.training, batches, event)

            if not self.final_stage:
                # Don't merge micro-batches to avoid unnecessary edges in autograd
                # graph
                # FIXME(tom) should figure out a proper type here
                return batches  # type: ignore
            else:
                # Merge the micro-batches into one mini-batch.
                if self.pipelined_backward:
                    with torch.no_grad():
                        output = microbatch.gather(batches)

                    from .phony import get_phony

                    phony = get_phony(
                        torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
                        requires_grad=True,
                    )
                    output = PipelinedBackwardPass.apply(output, batches, phony, True)  # self.retain_graph)
                else:
                    output = microbatch.gather(batches)

        return output

    def back_helper(self, output: List[microbatch.Batch]) -> None:
        if self.final_stage:
            raise ValueError("back_helper should only be called on non-final stages")

        if self.pipeline:
            self.pipeline.back_helper(list(reversed(output)))


class PipelinedBackwardPass(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
        ctx.batches = batches
        ctx.retain_graph = retain_graph
        return input

    @staticmethod
    # type: ignore
    def backward(ctx, *grads) -> Tuple:
        with torch.no_grad():
            grad_batches = microbatch.scatter(grads, len(ctx.batches))

        for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
            for t in batch:
                t.retain_grad()
            torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)

        with torch.no_grad():
            if ctx.batches[0].atomic:
                tensors = tuple(b.tensor.grad for b in ctx.batches)
                output: TensorOrTensors = torch.cat(tensors)
            else:
                rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
                output_buf = []
                for tensors in zip(*rotated):
                    output_buf.append(torch.cat(tensors))

                output = tuple(output_buf)

        del ctx.batches
        return (output, None, None, None)
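Because `forward` on a non-final stage returns the un-merged micro-batches and only the final stage produces a merged output, a training step with `MultiProcessPipe` (in `PipelineStyle.MultiProcess` style) has to branch on `pipe.final_stage`, with non-final stages driving their backward pass through `back_helper`. A rough, illustrative sketch of that control flow; the helper name, `criterion`, and `optimizer` are assumptions for the example and not part of this commit:

    def train_step(pipe, batch, target, optimizer, criterion):
        # Hypothetical per-rank step: each process holds one pipeline stage.
        optimizer.zero_grad()
        output = pipe(batch)  # final stage: merged mini-batch; other stages: list of micro-batches

        if pipe.final_stage:
            loss = criterion(output, target)
            loss.backward()           # RecvOperator/SendOperator relay activation gradients upstream
        else:
            pipe.back_helper(output)  # pull gradients from the next stage and run backward locally

        optimizer.step()              # each rank updates only its own partition's parameters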
fairscale/nn/pipe/multiprocess_pipeline.py  (new file, 0 → 100644)

#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The multiprocess pipeline parallelism of Pipe."""

import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function

from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import (
    ACTIVATIONS_GRADS_QUEUE,
    PORTAL_QUEUE,
    SKIP_TENSOR_QUEUE,
    PipelineStyle,
    PipeMessage,
    TensorOrTensors,
    Tensors,
)
from .worker import Task

__all__: List[str] = []

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]


class SendOperator(torch.autograd.Function):
    """Send activations to the next pipeline stage"""

    @staticmethod
    # type: ignore
    def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
        assert src_rank == torch.distributed.get_rank()

        transport.send_message(
            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
        )
        return ()

    @staticmethod
    # type: ignore
    def backward(ctx, *grad: Tensor,) -> Tensors:
        return tuple(grad)


class RecvOperator(torch.autograd.Function):
    """Receive activations to the previous pipeline stage"""

    @staticmethod
    # type: ignore
    def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
        assert dst_rank == torch.distributed.get_rank()

        ctx.transport = transport
        ctx.index = index

        result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)

        def maybe_requires_grad(t: Tensor) -> Tensor:
            if t.dtype.is_floating_point:
                return t.requires_grad_()
            return t

        return tuple(maybe_requires_grad(r) for r in result)

    @staticmethod
    # type: ignore
    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()

        ctx.transport.send_message(
            PipeMessage(
                this_rank,
                ranks[ranks.index(this_rank) - 1],
                queue_name=ACTIVATIONS_GRADS_QUEUE,
                args=ctx.index,
                tensors=tuple(grad),
            ),
        )
        return (None, None, None, None, None)


# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional["Task"]]
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue


def create_task(
    style: PipelineStyle,
    checkpoint_stop: int,
    i: int,
    j: int,
    batch: Batch,
    partition: nn.Sequential,
    skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
    # Determine whether checkpointing or not.
    if i < checkpoint_stop:

        def function(
            input: TensorOrTensors,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
            chunk_id: int = i,
            part_id: int = j,
        ) -> TensorOrTensors:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                ret = partition(input)
                # We do a check here because the backtrace from the checkpoint backward code path
                # is very hard to make sense. It would be much easier to check earlier at this point.
                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
                return ret

        chk = Checkpointing(function, batch)
        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk  # TODO(tom) maybe remove

    else:

        def compute(
            batch: Batch = batch,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
            chunk_id: int = i,
            part_id: int = j,
        ) -> Batch:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                return batch.call(partition)

        task = Task(None, compute=compute, finalize=None)
        del compute  # TODO(tom) maybe remove

    return task


class MultiProcessPipeline:
    """The multiprocess pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
        style: PipelineStyle,
        group: Optional[torch.distributed.ProcessGroup] = None,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        final_stage: bool = False,
    ) -> None:
        self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
        self.skip_layout = skip_layout
        self.__checkpoint_stop = checkpoint_stop
        self.style = style
        self.group = group
        self.training: bool
        self.transport = MakeTransport(
            use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
            worker_map=worker_map,
            input_device=input_device,
        )
        self.input_device = input_device
        self.all_at_once = False
        self.callcount = 0
        self.final_stage = final_stage

    @property
    def checkpoint_stop(self) -> int:
        # Disable checkpointing if in eval mode.
        training = self.partitions[0].module.training
        if not training:
            return 0
        return self.__checkpoint_stop

    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        self.training = training

        m = len(batches)

        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

        if self.style is PipelineStyle.MultiProcess:
            assert self.group
            schedule = [(i, self.group.rank()) for i in range(m)]
            self.compute(batches, schedule, skip_trackers)
        elif self.style is PipelineStyle.AsyncSchedule:
            assert self.group
            rank = self.group.rank()
            event_loop = AsyncEventLoop(
                self.partitions,
                self.group,
                self.transport,
                self.training,
                self.checkpoint_stop,
            )
            if rank == 0 and not self.final_stage:
                logging.debug(f"{torch.distributed.get_rank()}: entered event head")
                event_loop.event_loop_head(batches, skip_trackers, event)
                logging.debug(f"{torch.distributed.get_rank()}: exited event head")
            elif self.final_stage:
                logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
                event_loop.event_loop_tail(batches, skip_trackers)
                logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
            else:
                logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
                event_loop.event_loop(len(batches), skip_trackers)
                logging.debug(f"{torch.distributed.get_rank()}: exited event loop")

        self.callcount += 1

    def get_batch_from_previous_stage(
        self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
    ) -> Batch:
        phony = torch.empty(0, device=self.input_device, requires_grad=True)
        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
        if len(result) == 1:
            batch = Batch(result[0], i)
        else:
            batch = Batch(result, i)

        self.recv_skip_tensors(skip_trackers, batches)

        return batch

    def send_skip_tensors(
        self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
    ) -> None:
        assert self.group
        for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
            life = skip_trackers[i].portals[(ns, name)].tensor_life
            loaded = skip_trackers[i].load(batch, ns, name)
            if loaded is not None:
                tensors = tuple([loaded])
            else:
                tensors = tuple()

            self.transport.send_message(
                PipeMessage(
                    this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
                ),
                sync=True,
            )

    def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
        while True:
            try:
                message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
                (si, ns, name, life) = message.args
                value: Optional[TensorOrTensors] = message.tensors

                assert isinstance(value, tuple)

                if len(value) == 0:
                    value = None
                else:
                    assert len(value) == 1
                    value = value[0]

                skip_trackers[si].save(batches[si], ns, name, value)
                old_life = skip_trackers[si].portals[(ns, name)].tensor_life
                if life != 0:
                    skip_trackers[si].portals[(ns, name)].tensor_life = life
            except QueueEmpty:
                break

    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
        batch = task.compute()

        assert self.group
        rank = self.group.rank()

        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
            ranks = get_pipeline_parallel_ranks()
            this_rank = torch.distributed.get_rank()

            self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
            SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)

        for portal in skip_trackers[i].portals.values():
            portal.pipeline = self

        task.finalize(batch)

        return batch

    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        if self.style is PipelineStyle.MultiProcess:
            assert self.group
            n = self.group.size()

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]

            if self.style is PipelineStyle.MultiProcess:
                assert len(self.partitions) == 1
                partition = self.partitions[0]

                assert self.group
                if self.group.rank() != 0:
                    batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)

                task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)

                batches[i] = self.execute_task(task, i, skip_trackers)

    def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
        dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
        if dest == src:
            return
        ranks = get_pipeline_parallel_ranks()
        dst_rank = ranks[dest]
        if dst_rank == torch.distributed.get_rank():
            return

        if isinstance(grad, Tensor):
            grad = tuple([grad])

        self.transport.send_message(
            PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad),
            sync=True,
        )

    def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
        message = self.transport.recv_message(PORTAL_QUEUE)

        (ns_name, index) = message.args
        grad = message.tensors

        assert len(grad) == 1
        result = grad[0]
        assert index == expected_index and ns_name == expected_ns_name
        return result

    def back_helper(self, output: List[Batch]) -> None:
        if self.style == PipelineStyle.AsyncSchedule:
            return

        o = list(output)

        tensors: Tensors

        if self.all_at_once:
            # FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
            grads = []
            for i, batch in enumerate(o):
                rank = torch.distributed.get_rank()
                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
                assert len(found) == 1
                grads.append(found[0])
            tensors = tuple(x.tensor_or_tensors for x in o)  # type: ignore
            try:
                torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
            except Exception as e:
                raise RuntimeError("Autograd failed") from e
        else:
            rank = torch.distributed.get_rank()
            for batch in o:
                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
                if batch.atomic:
                    tensors = tuple([batch.tensor])
                else:
                    tensors = batch.tensors

                if len(found) != len(tensors):
                    raise RuntimeError("different number of tensors and gradients")

                grads = []
                final_tensors = []
                for i, tensor in enumerate(tensors):
                    if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
                        grads.append(found[i])
                        final_tensors.append(tensor)

                try:
                    torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
                except Exception as e:
                    raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
fairscale/nn/pipe/pipe.py
View file @
cae9b638
...
@@ -19,29 +19,21 @@
...
@@ -19,29 +19,21 @@
"""The Pipe interface."""
"""The Pipe interface."""
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
import
itertools
import
threading
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
import
warnings
import
torch
import
torch
from
torch
import
Tensor
,
nn
from
torch
import
Tensor
,
nn
import
torch.autograd
import
torch.autograd
import
torch.cuda
import
torch.cuda
from
fairscale.nn.model_parallel
import
get_model_parallel_world_size
,
get_pipeline_parallel_group
from
.
import
microbatch
from
.
import
microbatch
from
.async_schedule
import
Invocation
,
Location
,
ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.batchnorm
import
DeferredBatchNorm
from
.pipeline
import
Pipeline
from
.pipeline
import
Pipeline
from
.skip.layout
import
SkipLayout
,
inspect_skip_layout
from
.skip.layout
import
inspect_skip_layout
from
.skip.skippable
import
Skippable
,
verify_skippables
from
.skip.skippable
import
verify_skippables
from
.stream
import
AbstractStream
,
new_stream
from
.stream
import
AbstractStream
,
new_stream
from
.types
import
LazyModule
,
PipelineStyle
__all__
=
[
"Pipe"
,
"LazyModule"
]
__all__
=
[
"Pipe"
]
Device
=
Union
[
torch
.
device
,
int
,
str
]
Device
=
Union
[
torch
.
device
,
int
,
str
]
...
@@ -50,8 +42,6 @@ Devices = Union[Iterable[Device], List[Device]]
...
@@ -50,8 +42,6 @@ Devices = Union[Iterable[Device], List[Device]]
Tensors
=
Tuple
[
Tensor
,
...]
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
ListOfLazyModules
=
List
[
LazyModule
]
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
Module
=
nn
.
Module
[
TensorOrTensors
]
Module
=
nn
.
Module
[
TensorOrTensors
]
NamedModules
=
OrderedDict
[
str
,
Module
]
NamedModules
=
OrderedDict
[
str
,
Module
]
...
@@ -79,34 +69,17 @@ naive automatic balancing:
...
@@ -79,34 +69,17 @@ naive automatic balancing:
"""
"""
# FIXME(tom) make this a valid way to call
def
verify_module
(
module
:
nn
.
Sequential
)
->
None
:
def
verify_list_of_callable
(
module
:
Union
[
nn
.
Sequential
,
list
])
->
None
:
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
module
:
raise
TypeError
(
"module must be nn.Sequential to be partitioned"
)
if
isinstance
(
layer
,
nn
.
Module
):
pass
elif
isinstance
(
layer
,
LazyModule
):
pass
else
:
raise
TypeError
(
f
"layer
{
type
(
layer
)
}
must be nn.Module or LazyModule to be partitioned"
)
def
verify_module
(
module
:
Union
[
nn
.
Sequential
,
ListOfLazyModules
])
->
None
:
named_children
=
list
(
module
.
named_children
())
if
isinstance
(
module
,
Iterable
)
and
not
isinstance
(
module
,
nn
.
Sequential
):
if
len
(
named_children
)
!=
len
(
module
):
verify_list_of_callable
(
module
)
raise
ValueError
(
"module with duplicate children is not supported"
)
else
:
if
not
isinstance
(
module
,
nn
.
Sequential
):
raise
TypeError
(
"module must be nn.Sequential to be partitioned"
)
named_children
=
list
(
module
.
named_children
())
if
len
(
named_children
)
!=
len
(
module
):
raise
ValueError
(
"module with duplicate children is not supported"
)
def
verify_splitting
(
def
verify_splitting
(
module
:
nn
.
Sequential
,
module
:
nn
.
Sequential
,
partitions
:
List
[
nn
.
Sequential
],
balance
:
Iterable
[
int
],
devices
:
List
[
torch
.
device
]
partitions
:
List
[
nn
.
Sequential
],
balance
:
Iterable
[
int
],
devices
:
Optional
[
List
[
torch
.
device
]],
)
->
None
:
)
->
None
:
num_parameters
=
len
(
list
(
module
.
parameters
()))
num_parameters
=
len
(
list
(
module
.
parameters
()))
num_child_parameters
=
sum
(
len
(
list
(
child
.
parameters
()))
for
child
in
module
.
children
())
num_child_parameters
=
sum
(
len
(
list
(
child
.
parameters
()))
for
child
in
module
.
children
())
...
@@ -117,7 +90,7 @@ def verify_splitting(
...
@@ -117,7 +90,7 @@ def verify_splitting(
for
j
in
range
(
i
+
1
,
len
(
partitions
)):
for
j
in
range
(
i
+
1
,
len
(
partitions
)):
parti
=
partitions
[
i
]
parti
=
partitions
[
i
]
partj
=
partitions
[
j
]
partj
=
partitions
[
j
]
if
devices
and
devices
[
i
]
==
devices
[
j
]:
if
devices
[
i
]
==
devices
[
j
]:
continue
continue
for
p
in
parti
.
parameters
():
for
p
in
parti
.
parameters
():
for
q
in
partj
.
parameters
():
for
q
in
partj
.
parameters
():
...
@@ -129,159 +102,9 @@ class BalanceError(ValueError):
    pass


-def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
-    if filter_unique:
-        module_len = len(set(map(id, module)))
-    else:
-        module_len = len(module)
-
-    if module_len != sum(balance):
-        raise BalanceError(
-            f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
-        )
-
-    if any(x <= 0 for x in balance):
-        raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
-
-
-@dataclass
-class PartitionInfo:
-    location: Location
-    modules: "OrderedDict[str, nn.Module]"
-    invocations: List[Invocation] = field(default_factory=list)
-
-    def __len__(self) -> int:
-        return len(self.modules)
-
-
-def instantiate_partition(
-    module: Union[nn.Sequential, ListOfLazyModules],
-    balance: Iterable[int],
-    group: torch.distributed.ProcessGroup,
-    style: PipelineStyle,
-) -> List[ModuleWrapper]:
-    balance = list(balance)
-    check_balance(module, balance, True)
-
-    layers: NamedModules = OrderedDict()
-
-    def maybe_realize(layer: Any) -> nn.Module:
-        if isinstance(layer, nn.Module):
-            return layer
-        elif callable(layer):
-            return layer()
-        else:
-            raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
-
-    def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
-        if isinstance(module, nn.Sequential):
-            yield from module.named_children()
-        else:
-            yield from ((str(k), v) for k, v in enumerate(module))
-
-    if style == PipelineStyle.AsyncSchedule:
-        module_ids = list(map(id, module))
-        index_of_first_use = [module_ids.index(x) for x in module_ids]
-        locations: List[Location] = []
-        module_iter = enumerate(iterate_module(module))
-
-        partitions: List[List[PartitionInfo]] = []
-        for bi, b in enumerate(balance):
-            modules_for_rank: List[PartitionInfo] = []
-            current_module: OrderedDict[str, nn.Module] = OrderedDict()
-
-            def current_location() -> Location:
-                return Location(bi, len(modules_for_rank))
-
-            def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
-                modules_for_rank.append(PartitionInfo(current_location(), mod))
-
-            while sum(map(len, modules_for_rank)) + len(current_module) < b:
-                module_index, (name, layer) = next(module_iter)
-
-                if index_of_first_use[module_index] != module_index:
-                    # Subsequent reuse of a module
-                    locations.append(locations[index_of_first_use[module_index]])
-                    continue
-
-                is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
-
-                if is_reused and len(current_module) > 0:
-                    append_module(current_module)
-                    current_module = OrderedDict()
-
-                current_module[str(name)] = layer
-                locations.append(current_location())
-
-                if is_reused:
-                    append_module(current_module)
-                    current_module = OrderedDict()
-
-            if len(current_module) > 0:
-                append_module(current_module)
-
-            partitions.append(modules_for_rank)
-
-        filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
-        filtered_locations.append(None)
-
-        for i in range(len(filtered_locations) - 1):
-            loc = filtered_locations[i]
-            assert loc
-            if i == 0:
-                inv = Invocation(i, loc, None, filtered_locations[i + 1])
-            else:
-                inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
-
-            partitions[loc.stage][loc.index].invocations.append(inv)
-
-        invocations = enumerate(iterate_module(module))
-
-        partition = partitions[group.rank()]
-        result: List[ModuleWrapper] = []
-
-        for partition_info in partition:
-            wrapper = ModuleWrapper(
-                nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
-                partition_info.location,
-                partition_info.invocations,
-            )
-
-            if not isinstance(module, nn.Sequential):
-                for layer in wrapper.module:
-                    if isinstance(layer, Skippable):
-                        raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
-
-            result.append(wrapper)
-
-        return result
-
-    j = 0
-
-    for name, layer in iterate_module(module):
-        layers[name] = layer
-
-        if len(layers) == balance[j]:
-            if j == group.rank():
-                for key in layers:
-                    layers[key] = maybe_realize(layers[key])
-                if not isinstance(module, nn.Sequential):
-                    for layer in layers.values():
-                        if isinstance(layer, Skippable):
-                            raise ValueError(
-                                "Can't use Skippable layers with multi-process pipe and lazy construction"
-                            )
-                return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
-
-            # Prepare for the next partition.
-            layers.clear()
-            j += 1
-
-    raise ValueError("Souldn't get here, more ranks than partitions")
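The removed instantiate_partition above assigns whole layers to pipeline ranks by walking the module in order until each rank has consumed its share of the balance list. A minimal standalone sketch of that greedy assignment is shown below; it is illustrative only (assign_layers is a hypothetical helper, not part of fairscale), but it follows the same rule: balance[k] layers in a row go to rank k.

# Illustrative sketch of balance-based layer assignment (not fairscale code).
from typing import Dict, List, Sequence


def assign_layers(num_layers: int, balance: Sequence[int]) -> Dict[int, List[int]]:
    assert num_layers == sum(balance), "balance must account for every layer"
    assignment: Dict[int, List[int]] = {}
    layer = 0
    for rank, count in enumerate(balance):
        # Rank `rank` owns the next `count` consecutive layers.
        assignment[rank] = list(range(layer, layer + count))
        layer += count
    return assignment


# Example: 6 layers split as [3, 2, 1] -> rank 0 gets layers 0-2, rank 1 gets 3-4, rank 2 gets 5.
print(assign_layers(6, [3, 2, 1]))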
def split_module(
-    module: nn.Sequential, balance: Iterable[int], devices: Optional[List[torch.device]],
-) -> Tuple[List[nn.Sequential], List[int], Optional[List[torch.device]]]:
+    module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
+) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
    """Splits a module into multiple partitions.

    Returns:
...
@@ -300,11 +123,18 @@ def split_module(
"""
"""
balance
=
list
(
balance
)
balance
=
list
(
balance
)
check_balance
(
module
,
balance
)
if
len
(
module
)
!=
sum
(
balance
):
raise
BalanceError
(
"module and sum of balance have different length "
f
"(module:
{
len
(
module
)
}
, sum of balance:
{
sum
(
balance
)
}
)"
)
if
devices
and
len
(
balance
)
>
len
(
devices
):
if
any
(
x
<=
0
for
x
in
balance
):
raise
BalanceError
(
f
"all balance numbers must be positive integer (balance:
{
balance
}
)"
)
if
len
(
balance
)
>
len
(
devices
):
raise
IndexError
(
raise
IndexError
(
f
"too few devices to hold given partitions (devices:
{
len
(
devices
)
}
, partitions:
{
len
(
balance
)
}
)"
"too few devices to hold given partitions
"
f
"
(devices:
{
len
(
devices
)
}
, partitions:
{
len
(
balance
)
}
)"
)
)
j
=
0
j
=
0
...
@@ -318,9 +148,8 @@ def split_module(
...
@@ -318,9 +148,8 @@ def split_module(
            # Group buffered layers as a partition.
            partition = nn.Sequential(layers)

-            if devices:
-                device = devices[j]
-                partition.to(device)
+            device = devices[j]
+            partition.to(device)

            partitions.append(partition)
...
@@ -329,13 +158,12 @@ def split_module(
            j += 1

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
-    if devices:
-        del devices[j:]
+    del devices[j:]

    return partitions, balance, devices


-MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
+MOVING_DENIED = TypeError(
+    "denied to move parameters and buffers, " "because Pipe should manage device placement"
+)


class Pipe(Module):
...
@@ -365,23 +193,8 @@ class Pipe(Module):
            list of number of layers in each partition

    Keyword Args:
-        style (PipelineStyle):
-            whether to use a single process for all pipeline stages or to assign
-            one stage per process
        devices (iterable of devices):
            devices to use (default: all CUDA devices)
-        group (ProcessGroup):
-            specific to `style=MultiProcess`, the process group that all
-            pipeline stages are a member of. Defaults to
-            `get_pipeline_parallel_group()`
-        worker_map (Dict[int, str]):
-            a map from worker name (the first argument to
-            `torch.distributed.rpc.init_rpc`) to global rank (i.e.
-            `torch.distributed.get_rank()`) needed in order for pipeline stages
-            to communicate with each other
-        input_device (device):
-            the device on which tensors should be located before being passed to
-            the first module in a given pipeline stage
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
...
@@ -389,18 +202,8 @@ class Pipe(Module):
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``)
        deferred_batch_norm (bool):
            whether to use deferred BatchNorm moving statistics (default:
-            :data:`False`, see :class:`DeferredBatchNorm` for more
+            :data:`False`, see :ref:`Deferred Batch Normalization` for more
            details)
-        pipelined_backward (bool, optional):
-            if True, call torch.autograd.backward once per microbatch on the
-            backward pass (instead of once for the whole batch). This works
-            around a potential deadlock in pytorch when using tensor parallelism
-            at the same time. Defaults to `True` if
-            `get_model_parallel_world_size() > 1`
-            (default: `None`)
-        retain_graph (bool):
-            The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
-            (default: = `True`)

    Raises:
        TypeError:
...
@@ -412,10 +215,6 @@ class Pipe(Module):
"""
"""
SingleProcess
:
PipelineStyle
=
PipelineStyle
.
SingleProcess
MultiProcess
:
PipelineStyle
=
PipelineStyle
.
MultiProcess
AsyncSchedule
:
PipelineStyle
=
PipelineStyle
.
AsyncSchedule
#: The number of layers in each partition.
#: The number of layers in each partition.
balance
:
List
[
int
]
=
[]
balance
:
List
[
int
]
=
[]
# ^^
# ^^
...
@@ -435,7 +234,7 @@ class Pipe(Module):
...
@@ -435,7 +234,7 @@ class Pipe(Module):
#: output = pipe(input)
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#: loss = F.cross_entropy(output, target)
#:
#:
devices
:
Optional
[
List
[
torch
.
device
]
]
=
None
devices
:
List
[
torch
.
device
]
=
[]
#: The number of micro-batches.
#: The number of micro-batches.
chunks
:
int
=
1
chunks
:
int
=
1
...
@@ -446,20 +245,13 @@ class Pipe(Module):
...
@@ -446,20 +245,13 @@ class Pipe(Module):
    def __init__(
        self,
-        module: Union[nn.Sequential, ListOfLazyModules],
+        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
-        style: PipelineStyle = PipelineStyle.SingleProcess,
        devices: Optional[Devices] = None,
-        group: Optional[torch.distributed.ProcessGroup] = None,
-        worker_map: Optional[Dict[int, str]] = None,
-        input_device: Union[None, int, str, torch.device] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
-        pipelined_backward: bool = None,
-        retain_graph: bool = False,
-        loss_fn: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
...
@@ -477,145 +269,50 @@ class Pipe(Module):
        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
-        if isinstance(module, nn.Sequential):
-            verify_skippables(module)
+        verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint
-        self.pipelined_backward = pipelined_backward
-        self.retain_graph = retain_graph
-        self.pipeline: Optional[Pipeline]
-        self.loss_fn = loss_fn
-        self.lock = threading.Lock()
-
-        self.group = group
-        self.worker_map = worker_map
-        self.input_device = input_device
-
-        self._copy_streams: List[List[AbstractStream]] = []
-
-        # The micro-batch index where the checkpointing stops.
-        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
-
-        if style is PipelineStyle.SingleProcess:
-            module = cast(nn.Sequential, module)
-            if deferred_batch_norm:
-                module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
-
-            if input_device is not None:
-                raise ValueError("'input_device' argument only applies to 'PipelineStyle.MultiProcess'")
-
-            if devices is None:
-                devices = range(torch.cuda.device_count())
-            devices = [torch.device(d) for d in devices]
-            devices = cast(List[torch.device], devices)
-
-            try:
-                self.partitions, self.balance, self.devices = split_module(module, balance, devices)
-            except BalanceError as exc:
-                raise ValueError(recommend_auto_balance(str(exc)))
-            verify_splitting(module, self.partitions, self.balance, self.devices)
-
-            self._skip_layout = inspect_skip_layout(self.partitions)
-
-            # Separate CUDA streams for copy.
-            copy_streams = self._ensure_copy_streams()
-            if self.pipelined_backward is None:
-                self.pipelined_backward = False
-
-            self.pipeline = Pipeline(
-                self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style,
-            )
-
-        elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
-            if self.group is None:
-                self.group = get_pipeline_parallel_group()
-            assert self.group
-
-            if devices is not None:
-                raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'")
-
-            self.balance = list(balance)
-
-            if self.group.size() < len(self.balance):
-                raise IndexError(
-                    f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
-                    f" {len(self.balance)})"
-                )
-            try:
-                rank = self.group.rank()
-                if rank >= len(self.balance):
-                    warnings.warn("More ranks than partitions, some ranks unused")
-                    self.mp_partitions: List[ModuleWrapper] = []
-                else:
-                    self.mp_partitions = instantiate_partition(module, balance, self.group, style)
-                    if deferred_batch_norm:
-                        for part in self.mp_partitions:
-                            part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
-                    for name, part in enumerate(self.mp_partitions):
-                        self.add_module(str(name), part.module)
-                self.devices = None
-                if isinstance(module, nn.Sequential):
-                    local_partitions, _, _ = split_module(module, balance, None)
-                    self._skip_layout = inspect_skip_layout(local_partitions)
-                else:
-                    self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
-            except BalanceError as exc:
-                raise ValueError(recommend_auto_balance(str(exc)))
-
-            rank = self.group.rank()
-            if rank >= len(self.balance):
-                self.pipeline = None
-                self.final_stage = False
-            else:
-                self.final_stage = rank == len(self.balance) - 1
-                assert loss_fn is None or self.final_stage
-
-                self.pipeline = Pipeline(
-                    cast(List[nn.Sequential], self.mp_partitions),
-                    None,
-                    None,
-                    self._skip_layout,
-                    checkpoint_stop,
-                    style=style,
-                    group=self.group,
-                    worker_map=self.worker_map,
-                    input_device=self.input_device,
-                    final_stage=self.final_stage,
-                )
-
-                del module
-
-            if self.pipelined_backward is None:
-                if get_model_parallel_world_size() > 1:
-                    self.pipelined_backward = True
-                else:
-                    self.pipelined_backward = False
+        if deferred_batch_norm:
+            module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
+
+        if devices is None:
+            devices = range(torch.cuda.device_count())
+        devices = [torch.device(d) for d in devices]
+        devices = cast(List[torch.device], devices)
+
+        try:
+            self.partitions, self.balance, self.devices = split_module(module, balance, devices)
+        except BalanceError as exc:
+            raise ValueError(recommend_auto_balance(str(exc)))
+
+        verify_splitting(module, self.partitions, self.balance, self.devices)
+
+        self._copy_streams: List[List[AbstractStream]] = []
+        self._skip_layout = inspect_skip_layout(self.partitions)
+
+        # Separate CUDA streams for copy.
+        copy_streams = self._ensure_copy_streams()
+
+        # The micro-batch index where the checkpointing stops.
+        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
+
+        self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
    def __len__(self) -> int:
        """Counts the length of the underlying sequential module."""
-        if hasattr(self, "partitions"):
-            return sum(len(p) for p in self.partitions)
-        else:
-            return sum(len(p) for p in self.mp_partitions)
+        return sum(len(p) for p in self.partitions)

    def __getitem__(self, index: int) -> nn.Module:
        """Gets a layer in the underlying sequential module."""
-        partitions: List[Any]
-        if hasattr(self, "partitions"):
-            partitions = self.partitions
-        else:
-            partitions = self.mp_partitions
+        partitions = self.partitions

        if index < 0:
            partitions = partitions[::-1]

        for partition in partitions:
            try:
-                if isinstance(partition, ModuleWrapper):
-                    return partition.module[index]
-                else:
-                    return partition[index]
+                return partition[index]
            except IndexError:
                pass
...
@@ -630,47 +327,35 @@ class Pipe(Module):
    def __iter__(self) -> Iterable[nn.Module]:
        """Iterates over children of the underlying sequential module."""
-        if hasattr(self, "partitions"):
-            for partition in self.partitions:
-                yield from partition
-        else:
-            for mp_partition in self.mp_partitions:
-                yield from mp_partition.module
+        for partition in self.partitions:
+            yield from partition

    # Pipe should manage the device of each partition.
    # Deny cuda(), cpu(), and to() with device, by TypeError.
    def cuda(self, device: Optional[Device] = None) -> "Pipe":
-        if self.devices:
-            raise MOVING_DENIED
-
-        if device:
-            return super().cuda(device=device)
-        else:
-            return super().cuda()
+        raise MOVING_DENIED

    def cpu(self) -> "Pipe":
-        if self.devices:
-            raise MOVING_DENIED
-        return super().cpu()
+        raise MOVING_DENIED

    def to(self, *args: Any, **kwargs: Any) -> "Pipe":
-        """Restrict .to() options.
-
-        Deny these usages:
-        - to(device[, dtype, non_blocking])
-        - to(tensor[, non_blocking])
-
-        But allow this:
-        - to(dtype[, non_blocking])
-        """
-        if self.devices:
-            if "device" in kwargs or "tensor" in kwargs:
-                raise MOVING_DENIED
-
-            if args:
-                if isinstance(args[0], (torch.device, int, str)):
-                    raise MOVING_DENIED
-
-                if torch.is_tensor(args[0]):
-                    raise MOVING_DENIED
+        # Deny these usages:
+        #
+        # - to(device[, dtype, non_blocking])
+        # - to(tensor[, non_blocking])
+        #
+        # But allow this:
+        #
+        # - to(dtype[, non_blocking])
+        #
+        if "device" in kwargs or "tensor" in kwargs:
+            raise MOVING_DENIED
+
+        if args:
+            if isinstance(args[0], (torch.device, int, str)):
+                raise MOVING_DENIED
+
+            if torch.is_tensor(args[0]):
+                raise MOVING_DENIED

        return super().to(*args, **kwargs)
...
@@ -683,13 +368,12 @@ class Pipe(Module):
"""
"""
if
not
self
.
_copy_streams
:
if
not
self
.
_copy_streams
:
assert
self
.
devices
is
not
None
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
self
.
_copy_streams
.
append
([
new_stream
(
device
)
for
_
in
range
(
self
.
chunks
)])
self
.
_copy_streams
.
append
([
new_stream
(
device
)
for
_
in
range
(
self
.
chunks
)])
return
self
.
_copy_streams
return
self
.
_copy_streams
def
forward
(
self
,
input
:
TensorOrTensors
,
*
,
event
=
None
)
->
TensorOrTensors
:
# type: ignore
def
forward
(
self
,
input
:
TensorOrTensors
)
->
TensorOrTensors
:
# type: ignore
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
there's type restriction. Input and output have to be a
...
@@ -708,82 +392,16 @@ class Pipe(Module):
...
@@ -708,82 +392,16 @@ class Pipe(Module):
"""
"""
microbatch
.
check
(
input
)
microbatch
.
check
(
input
)
if
not
self
.
group
and
not
self
.
devices
:
if
not
self
.
devices
:
# Empty sequential module is not illegal.
# Empty sequential module is not illegal.
return
input
return
input
if
not
self
.
pipeline
:
# No pipeline is not illegal, more ranks than partitions
return
input
# Divide a mini-batch into micro-batches.
# Divide a mini-batch into micro-batches.
batches
=
microbatch
.
scatter
(
input
,
self
.
chunks
)
batches
=
microbatch
.
scatter
(
input
,
self
.
chunks
)
# Run pipeline parallelism.
# Run pipeline parallelism.
with
self
.
lock
:
self
.
pipeline
.
run
(
batches
)
self
.
pipeline
.
run
(
self
.
training
,
batches
,
event
)
if
self
.
group
and
not
self
.
final_stage
:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return
batches
# type: ignore
else
:
# Merge the micro-batches into one mini-batch.
if
self
.
pipelined_backward
:
with
torch
.
no_grad
():
output
=
microbatch
.
gather
(
batches
)
from
.phony
import
get_phony
phony
=
get_phony
(
torch
.
device
(
torch
.
cuda
.
current_device
()
if
torch
.
cuda
.
is_available
()
else
"cpu"
),
requires_grad
=
True
,
)
output
=
PipelinedBackwardPass
.
apply
(
output
,
batches
,
phony
,
True
)
# self.retain_graph)
else
:
output
=
microbatch
.
gather
(
batches
)
return
output
def
back_helper
(
self
,
output
:
List
[
microbatch
.
Batch
])
->
None
:
if
self
.
final_stage
:
raise
ValueError
(
"back_helper should only be called on non-final stages"
)
if
self
.
pipeline
:
self
.
pipeline
.
back_helper
(
list
(
reversed
(
output
)))
class
PipelinedBackwardPass
(
torch
.
autograd
.
Function
):
@
staticmethod
# type: ignore
def
forward
(
ctx
,
input
:
TensorOrTensors
,
batches
,
phony
,
retain_graph
)
->
TensorOrTensors
:
ctx
.
batches
=
batches
ctx
.
retain_graph
=
retain_graph
return
input
@
staticmethod
# type: ignore
def
backward
(
ctx
,
*
grads
)
->
Tuple
:
with
torch
.
no_grad
():
grad_batches
=
microbatch
.
scatter
(
grads
,
len
(
ctx
.
batches
))
for
grad
,
batch
in
reversed
(
list
(
zip
(
grad_batches
,
ctx
.
batches
))):
for
t
in
batch
:
t
.
retain_grad
()
torch
.
autograd
.
backward
(
batch
.
tensor_or_tensors
,
grad_tensors
=
(
*
grad
,),
retain_graph
=
ctx
.
retain_graph
)
with
torch
.
no_grad
():
if
ctx
.
batches
[
0
].
atomic
:
tensors
=
tuple
(
b
.
tensor
.
grad
for
b
in
ctx
.
batches
)
output
:
TensorOrTensors
=
torch
.
cat
(
tensors
)
else
:
rotated
=
[[
t
.
grad
for
t
in
b
.
tensors
]
for
b
in
ctx
.
batches
]
output_buf
=
[]
for
tensors
in
zip
(
*
rotated
):
output_buf
.
append
(
torch
.
cat
(
tensors
))
output
=
tuple
(
output_buf
)
del
ctx
.
batches
return
(
output
,
None
,
None
,
None
)
# Merge the micro-batches into one mini-batch.
output
=
microbatch
.
gather
(
batches
)
return
output
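For orientation, a rough usage sketch of the two classes after this split. The class names, `balance`, `style`, `worker_map`, and `chunks` arguments come from the diff above; the model, device list, and worker names are made up for illustration, and the multi-process variant additionally assumes an initialized pipeline process group (via torch.distributed/rpc), so treat this as a sketch rather than a verified call site.

import torch.nn as nn

from fairscale.nn.pipe import MultiProcessPipe, Pipe

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# Single-process pipe: all partitions live on devices owned by this process.
# Two CPU "devices" are used here only so the example runs without GPUs.
single = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=4)

# Multi-process pipe: one stage per rank, stages talk via the worker_map.
# The worker names below are placeholders; a real run needs torch.distributed
# and rpc initialized first (see the tests touched by this commit).
multi = MultiProcessPipe(
    model,
    balance=[2, 1],
    style=MultiProcessPipe.MultiProcess,
    worker_map={0: "worker0", 1: "worker1"},
    chunks=4,
)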
fairscale/nn/pipe/pipeline.py  View file @ cae9b638
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
...
@@ -17,101 +18,30 @@
# limitations under the License.
"""The pipeline parallelism of Pipe."""
-import logging
-import os
-from queue import Empty as QueueEmpty
from queue import Queue
-from threading import Event
from types import TracebackType
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function

-from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
-
-from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
-from .messages import MakeTransport, Transport
from .microbatch import Batch
-from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
-from .types import (
-    ACTIVATIONS_GRADS_QUEUE,
-    PORTAL_QUEUE,
-    SKIP_TENSOR_QUEUE,
-    PipelineStyle,
-    PipeMessage,
-    Schedule,
-    TensorOrTensors,
-    Tensors,
-)
from .worker import Task, create_workers, join_workers

__all__: List[str] = []

-ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
-
-
-class SendOperator(torch.autograd.Function):
-    """Send activations to the next pipeline stage"""
-
-    @staticmethod
-    # type: ignore
-    def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
-        assert src_rank == torch.distributed.get_rank()
-
-        transport.send_message(
-            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
-        )
-        return ()
-
-    @staticmethod
-    # type: ignore
-    def backward(ctx, *grad: Tensor,) -> Tensors:
-        return tuple(grad)
-
-
-class RecvOperator(torch.autograd.Function):
-    """Receive activations to the previous pipeline stage"""
-
-    @staticmethod
-    # type: ignore
-    def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
-        assert dst_rank == torch.distributed.get_rank()
-
-        ctx.transport = transport
-        ctx.index = index
-
-        result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
-
-        def maybe_requires_grad(t: Tensor) -> Tensor:
-            if t.dtype.is_floating_point:
-                return t.requires_grad_()
-            return t
-
-        return tuple(maybe_requires_grad(r) for r in result)
-
-    @staticmethod
-    # type: ignore
-    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
-        ranks = get_pipeline_parallel_ranks()
-        this_rank = torch.distributed.get_rank()
-
-        ctx.transport.send_message(
-            PipeMessage(
-                this_rank,
-                ranks[ranks.index(this_rank) - 1],
-                queue_name=ACTIVATIONS_GRADS_QUEUE,
-                args=ctx.index,
-                tensors=tuple(grad),
-            ),
-        )
-        return (None, None, None, None, None)
+Tensors = Tuple[Tensor, ...]
+TensorOrTensors = Union[Tensor, Tensors]
+
+ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
...
@@ -140,7 +70,7 @@ def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream)
    batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])


-def clock_cycles(m: int, n: int) -> Iterable[Schedule]:
+def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    """Generates schedules for each clock cycle."""
    # m: number of micro-batches
    # n: number of partitions
...
@@ -159,159 +89,45 @@ def clock_cycles(m: int, n: int) -> Iterable[Schedule]:
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
-def create_task(
-    style: PipelineStyle,
-    checkpoint_stop: int,
-    i: int,
-    j: int,
-    batch: Batch,
-    partition: nn.Sequential,
-    skip_trackers: List[SkipTrackerThroughPotals],
-    streams: List[AbstractStream],
-) -> Task:
-    # Determine whether checkpointing or not.
-    if i < checkpoint_stop:
-
-        def function(
-            input: TensorOrTensors,
-            partition: nn.Sequential = partition,
-            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
-            chunk_id: int = i,
-            part_id: int = j,
-        ) -> TensorOrTensors:
-            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
-                ret = partition(input)
-            # We do a check here because the backtrace from the checkpoint backward code path
-            # is very hard to make sense. It would be much easier to check earlier at this point.
-            assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
-            return ret
-
-        chk = Checkpointing(function, batch)
-        if style is PipelineStyle.SingleProcess:
-            task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
-        elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
-            task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
-        del function, chk  # TODO(tom) maybe remove
-    else:
-
-        def compute(
-            batch: Batch = batch,
-            partition: nn.Sequential = partition,
-            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
-            chunk_id: int = i,
-            part_id: int = j,
-        ) -> Batch:
-            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
-                return batch.call(partition)
-
-        if style is PipelineStyle.SingleProcess:
-            task = Task(streams[j], compute=compute, finalize=None)
-        elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
-            task = Task(None, compute=compute, finalize=None)
-        del compute  # TODO(tom) maybe remove
-
-    return task
class Pipeline:
    """The pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
-        devices: Optional[List[torch.device]],
-        copy_streams: Optional[List[List[AbstractStream]]],
+        devices: List[torch.device],
+        copy_streams: List[List[AbstractStream]],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
-        style: PipelineStyle,
-        group: Optional[torch.distributed.ProcessGroup] = None,
-        worker_map: Optional[Dict[int, str]] = None,
-        input_device: Union[None, int, str, torch.device] = None,
-        final_stage: bool = False,
    ) -> None:
-        if style == PipelineStyle.SingleProcess:
-            self.partitions = partitions
-        else:
-            self.mp_partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
+        self.partitions = partitions
        self.devices = devices
        self.copy_streams = copy_streams
        self.skip_layout = skip_layout
-        self.__checkpoint_stop = checkpoint_stop
-        self.style = style
-        self.group = group
-        self.training: bool
-        if style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
-            self.transport = MakeTransport(
-                use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
-                worker_map=worker_map,
-                input_device=input_device,
-            )
-        self.input_device = input_device
-        self.all_at_once = False
-        self.callcount = 0
-        self.final_stage = final_stage
-
-        if self.style is PipelineStyle.SingleProcess:
-            assert self.devices is not None
-            (self.in_queues, self.out_queues) = create_workers(self.devices)
-
-    @property
-    def checkpoint_stop(self) -> int:
-        # Disable checkpointing if in eval mode.
-        if self.style == PipelineStyle.SingleProcess:
-            training = self.partitions[0].training
-        else:
-            training = self.mp_partitions[0].module.training
-        if not training:
-            return 0
-        return self.__checkpoint_stop
+        self.checkpoint_stop = checkpoint_stop
+        (self.in_queues, self.out_queues) = create_workers(devices)

    def __del__(self) -> None:
-        if self.style is PipelineStyle.SingleProcess:
-            join_workers(self.in_queues, self.out_queues)
+        join_workers(self.in_queues, self.out_queues)

-    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
+    def run(self, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
-        self.training = training
-
+        partitions = self.partitions
+        devices = self.devices
+        skip_layout = self.skip_layout
+
        m = len(batches)
-        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
-
-        if self.style is PipelineStyle.SingleProcess:
-            n = len(self.partitions)
-            for schedule in clock_cycles(m, n):
-                self.fence(batches, schedule, skip_trackers)
-                self.compute(batches, schedule, skip_trackers)
-        elif self.style is PipelineStyle.MultiProcess:
-            assert self.group
-            schedule = [(i, self.group.rank()) for i in range(m)]
-            self.compute(batches, schedule, skip_trackers)
-        elif self.style is PipelineStyle.AsyncSchedule:
-            assert self.group
-            rank = self.group.rank()
-            event_loop = AsyncEventLoop(
-                self.mp_partitions, self.group, self.transport, self.training, self.checkpoint_stop,
-            )
-            if rank == 0 and not self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event head")
-                event_loop.event_loop_head(batches, skip_trackers, event)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event head")
-            elif self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
-                event_loop.event_loop_tail(batches, skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
-            else:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
-                event_loop.event_loop(len(batches), skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
-
-        self.callcount += 1
+        n = len(partitions)
+        skip_trackers = [SkipTrackerThroughPotals(skip_layout, i) for i in range(m)]
+
+        for schedule in clock_cycles(m, n):
+            self.fence(batches, schedule, skip_trackers)
+            self.compute(batches, schedule, skip_trackers)

    def fence(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
...
@@ -322,9 +138,6 @@ class Pipeline:
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

-        assert copy_streams
-        assert skip_layout
-
        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
...
@@ -341,91 +154,92 @@ class Pipeline:
                prev_stream = copy_streams[j - 1][i]
                copy(batches[i], prev_stream, next_stream)

-    def get_batch_from_previous_stage(
-        self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
-    ) -> Batch:
-        phony = torch.empty(0, device=self.input_device, requires_grad=True)
-        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
-        if len(result) == 1:
-            batch = Batch(result[0], i)
-        else:
-            batch = Batch(result, i)
-
-        self.recv_skip_tensors(skip_trackers, batches)
-
-        return batch
-
-    def send_skip_tensors(
-        self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
-        assert self.group
-        for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
-            life = skip_trackers[i].portals[(ns, name)].tensor_life
-            loaded = skip_trackers[i].load(batch, ns, name)
-            if loaded is not None:
-                tensors = tuple([loaded])
-            else:
-                tensors = tuple()
-
-            self.transport.send_message(
-                PipeMessage(
-                    this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
-                ),
-                sync=True,
-            )
-
-    def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
-        while True:
-            try:
-                message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
-                (si, ns, name, life) = message.args
-                value: Optional[TensorOrTensors] = message.tensors
-
-                assert isinstance(value, tuple)
-
-                if len(value) == 0:
-                    value = None
-                else:
-                    assert len(value) == 1
-                    value = value[0]
-
-                skip_trackers[si].save(batches[si], ns, name, value)
-                old_life = skip_trackers[si].portals[(ns, name)].tensor_life
-                if life != 0:
-                    skip_trackers[si].portals[(ns, name)].tensor_life = life
-            except QueueEmpty:
-                break
-
-    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
-        batch = task.compute()
-
-        assert self.group
-        rank = self.group.rank()
-
-        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
-            ranks = get_pipeline_parallel_ranks()
-            this_rank = torch.distributed.get_rank()
-
-            self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
-            SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
-
-        for portal in skip_trackers[i].portals.values():
-            portal.pipeline = self
-
-        task.finalize(batch)
-
-        return batch
-
-    def finalize_tasks(
-        self,
-        n: int,
-        schedule: Schedule,
-        streams: List[AbstractStream],
-        copy_streams: List[List[AbstractStream]],
-        batches: List[Batch],
-    ) -> None:
-        exc_info: Optional[ExcInfo] = None
-        for i, j in schedule:
-            ok, payload = self.out_queues[j].get()
+    def compute(
+        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+    ) -> None:
+        """Runs tasks with synchronization to copy streams."""
+        partitions = self.partitions
+        devices = self.devices
+        copy_streams = self.copy_streams
+        checkpoint_stop = self.checkpoint_stop
+
+        # Disable checkpointing if in eval mode.
+        if not self.partitions[0].training:
+            checkpoint_stop = 0
+
+        n = len(partitions)
+        streams = [current_stream(d) for d in devices]
+        exc_info: Optional[ExcInfo] = None
+
+        # With checkpointing, the autograd graph looks like this diagram:
+        # ┌─────┸──────┐
+        # │    Copy    │
+        # └─────┰──────┘   (fence)
+        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
+        #       ┃          (compute)
+        # ┌─────┸──────┐
+        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
+        # └─────┰──────┘
+        # ┌─────┸──────┐
+        # │ Checkpoint │ [2] Compute a partition within checkpointing.
+        # └─────┰──────┘
+        # ┌─────┸──────┐
+        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
+        # └─────┰──────┘
+        #       ┠ ─ ─ ─ ┐
+        #       ┃       ┌─────┴─────┐
+        #       ┃       │ Recompute │ [4] Schedule the recomputation at backpropagation.
+        #       ┃       └─────┬─────┘
+        #       ┠ ─ ─ ─ ┘
+        #       ┃
+        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
+        # ┌─────┸──────┐   (fence)
+        # │    Copy    │
+        # └─────┰──────┘
+        for i, j in schedule:
+            batch = batches[i]
+            partition = partitions[j]
+
+            # Synchronize with the copied input. ([1] in the diagram)
+            if j != 0:
+                wait(batch, copy_streams[j][i], streams[j])
+
+            # Determine whether checkpointing or not.
+            checkpoint = i < checkpoint_stop
+            if checkpoint:
+
+                def function(
+                    input: TensorOrTensors,
+                    partition: nn.Sequential = partition,
+                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
+                    chunk_id: int = i,
+                    part_id: int = j,
+                ) -> TensorOrTensors:
+                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                        return partition(input)
+
+                chk = Checkpointing(function, batch)
+                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
+                del function, chk
+
+            else:
+
+                def compute(
+                    batch: Batch = batch,
+                    partition: nn.Sequential = partition,
+                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
+                    chunk_id: int = i,
+                    part_id: int = j,
+                ) -> Batch:
+                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                        return batch.call(partition)
+
+                task = Task(streams[j], compute=compute, finalize=None)
+                del compute
+
+            # Compute tasks in parallel. ([2] in the diagram)
+            self.in_queues[j].put(task)
+
+        for i, j in schedule:
+            ok, payload = self.out_queues[j].get()
...
@@ -446,8 +260,7 @@ class Pipeline:
            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
-            assert self.devices
-            with use_device(self.devices[j]):
+            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch
...
@@ -455,147 +268,3 @@ class Pipeline:
        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

-    def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
-        """Runs tasks with synchronization to copy streams."""
-        devices = self.devices
-        copy_streams = self.copy_streams
-
-        if self.style is PipelineStyle.SingleProcess:
-            assert devices is not None
-            n = len(self.partitions)
-            streams = [current_stream(d) for d in devices]
-        elif self.style is PipelineStyle.MultiProcess:
-            assert self.group
-            n = self.group.size()
-            streams = []
-
-        # With checkpointing, the autograd graph looks like this diagram:
-        # … (same checkpointing diagram as in compute() above) …
-
-        for i, j in schedule:
-            batch = batches[i]
-
-            if self.style is PipelineStyle.SingleProcess:
-                partition = self.partitions[j]
-
-                # Synchronize with the copied input. ([1] in the diagram)
-                assert copy_streams
-                if j != 0:
-                    wait(batch, copy_streams[j][i], streams[j])
-
-                task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition, skip_trackers, streams)
-
-                # Compute tasks in parallel. ([2] in the diagram)
-                self.in_queues[j].put(task)
-            elif self.style is PipelineStyle.MultiProcess:
-                assert len(self.mp_partitions) == 1
-                mp_partition = self.mp_partitions[0]
-
-                assert self.group
-                if self.group.rank() != 0:
-                    batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
-
-                task = create_task(
-                    self.style, self.checkpoint_stop, i, j, batch, mp_partition.module, skip_trackers, streams
-                )
-
-                batches[i] = self.execute_task(task, i, skip_trackers)
-
-        if self.style is PipelineStyle.SingleProcess:
-            assert copy_streams
-            self.finalize_tasks(n, schedule, streams, copy_streams, batches)
-
-    def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
-        dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
-        if dest == src:
-            return
-        ranks = get_pipeline_parallel_ranks()
-        dst_rank = ranks[dest]
-        if dst_rank == torch.distributed.get_rank():
-            return
-
-        if isinstance(grad, Tensor):
-            grad = tuple([grad])
-        self.transport.send_message(
-            PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad),
-            sync=True,
-        )
-
-    def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
-        message = self.transport.recv_message(PORTAL_QUEUE)
-
-        (ns_name, index) = message.args
-        grad = message.tensors
-
-        assert len(grad) == 1
-        result = grad[0]
-        assert index == expected_index and ns_name == expected_ns_name
-        return result
-
-    def back_helper(self, output: List[Batch]) -> None:
-        if self.style == PipelineStyle.AsyncSchedule:
-            return
-
-        o = list(output)
-
-        tensors: Tensors
-
-        if self.all_at_once:
-            # FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
-            grads = []
-            for i, batch in enumerate(o):
-                rank = torch.distributed.get_rank()
-                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
-                assert len(found) == 1
-                grads.append(found[0])
-            tensors = tuple(x.tensor_or_tensors for x in o)  # type: ignore
-            try:
-                torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
-            except Exception as e:
-                raise RuntimeError("Autograd failed") from e
-        else:
-            rank = torch.distributed.get_rank()
-            for batch in o:
-                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
-                if batch.atomic:
-                    tensors = tuple([batch.tensor])
-                else:
-                    tensors = batch.tensors
-
-                if len(found) != len(tensors):
-                    raise RuntimeError("different number of tensors and gradients")
-
-                grads = []
-                final_tensors = []
-                for i, tensor in enumerate(tensors):
-                    if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
-                        grads.append(found[i])
-                        final_tensors.append(tensor)
-
-                try:
-                    torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
-                except Exception as e:
-                    raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
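As a quick aid for the scheduling logic kept in this file, here is a standalone restatement of the GPipe clock-cycle schedule that clock_cycles() yields. The yield expression is taken verbatim from the hunk above; the loop bound of m + n - 1 clock cycles sits in the collapsed part of the hunk and is reconstructed here as an assumption, so treat this as an illustration rather than the exact function body.

from typing import Iterable, List, Tuple


def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    # At clock k, micro-batch i = k - j runs on partition j (one entry per busy partition).
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


# For m=3 micro-batches and n=2 partitions this prints:
#   [(0, 0)], [(1, 0), (0, 1)], [(2, 0), (1, 1)], [(2, 1)]
for schedule in clock_cycles(3, 2):
    print(schedule)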
fairscale/nn/pipe/rpc.py  View file @ cae9b638
...
@@ -13,13 +13,13 @@ from torch.distributed.distributed_c10d import _get_global_rank
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group

-from . import Pipe
+from .multiprocess_pipe import MultiProcessPipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

-PipeModel: Pipe
+PipeModel: MultiProcessPipe
PipeResult: TensorOrTensors
...
@@ -71,7 +71,7 @@ class PipeBackRedirect(torch.autograd.Function):
        return (None, None, None, None, None, None)


-def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None:
+def callback_with_model(callback: Callable[[Any, MultiProcessPipe], None], ctx: Any) -> None:
    try:
        group = get_pipeline_parallel_group()  # FIXME(tom) handle dynamic group
        set_device_based_on_group(group)
...
@@ -105,10 +105,10 @@ class PipeRPCWrapper(nn.Module):
        else:
            kwargs["group"] = self.group

-        kwargs["style"] = Pipe.AsyncSchedule
+        kwargs["style"] = MultiProcessPipe.AsyncSchedule
        kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())

-        self.model = Pipe(*args, **kwargs)
+        self.model = MultiProcessPipe(*args, **kwargs)
        self.worker_map = kwargs["worker_map"]
        self._foreach_worker(self._register_remote_model, args=(args, kwargs))
        self.model.cuda()
...
@@ -121,7 +121,7 @@ class PipeRPCWrapper(nn.Module):
        futures = [f.wait() for f in futures]

    def foreach_worker(
-        self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False
+        self, callback: Callable[[Any, MultiProcessPipe], None], ctx: Any = None, *, include_self: bool = False
    ) -> None:
        """Call `callback` on each worker with the `ctx` and model local to that
        worker. e.g.
...
@@ -196,7 +196,9 @@ class PipeRPCWrapper(nn.Module):
        return self.model.final_stage

    @staticmethod
-    def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors:
+    def _recv_result(
+        model: MultiProcessPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage
+    ) -> TensorOrTensors:
        group = get_pipeline_parallel_group()
        set_device_based_on_group(group)
...
@@ -243,7 +245,7 @@ class PipeRPCWrapper(nn.Module):
        set_device_based_on_group(group)
        kwargs["group"] = group
        kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
-        model = Pipe(*args, **kwargs)
+        model = MultiProcessPipe(*args, **kwargs)
        model.cuda()
        global PipeModel
        PipeModel = model
...
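The rename above changes the model type that PipeRPCWrapper callbacks receive: they now take a MultiProcessPipe. A minimal hedged sketch of such a callback follows; the gradient-zeroing body, the ctx argument, and the commented call site are illustrative, not taken from the source.

from typing import Any

from fairscale.nn.pipe import MultiProcessPipe


def zero_grads(ctx: Any, model: MultiProcessPipe) -> None:
    # Runs on each worker against that worker's local pipeline stage.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()


# Assuming `wrapper` is an already-initialized PipeRPCWrapper instance:
# wrapper.foreach_worker(zero_grads, include_self=True)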
fairscale/nn/pipe/types.py  View file @ cae9b638
...
@@ -24,7 +24,6 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
InputDevice = Union[None, int, str, torch.device]
-Schedule = List[Tuple[int, int]]


class LazyModule:
...
stubs/torch/multiprocessing/__init__.pyi  View file @ cae9b638
...
@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Tuple
from torch import Tensor

def spawn(
-    fn: Callable[[Any], Any],
+    fn: Callable[..., Any],
    args: Tuple[Optional[Any], ...] = (),
    nprocs: int = 1,
    join: bool = True,
...
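The stub widening above matters because torch.multiprocessing.spawn calls the target as fn(rank, *args), so a single-argument Callable[[Any], Any] under-specifies it. A small self-contained example of that calling convention (the worker function and message are made up for illustration):

import torch.multiprocessing as mp


def worker(rank: int, message: str) -> None:
    # spawn passes the process index first, then the entries of `args`,
    # which is why the stub needs Callable[..., Any].
    print(f"rank {rank}: {message}")


if __name__ == "__main__":
    mp.spawn(worker, args=("hello",), nprocs=2, join=True)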
tests/nn/model_parallel/test_layers.py  View file @ cae9b638
...
@@ -31,7 +31,7 @@ from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
...
@@ -319,7 +319,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
    model_parallel_size = mpu.get_model_parallel_world_size()
    if torch.distributed.get_rank() == 0:
        print(
-            "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
+            "> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )
...
@@ -431,13 +431,13 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
        model[2].weight.data = saved_weight_2
        worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

-        style = Pipe.MultiProcess  # Pipe.AsyncSchedule
+        style = MultiProcessPipe.MultiProcess  # MultiProcessPipe.AsyncSchedule

        if pipe_world_size == 2:
            print(f"actually doing pipe stuff now")
            assert torch.equal(saved_weight_0, model[0].weight.data)
            assert torch.equal(saved_weight_2, model[2].weight.data)
-            pipe_model = Pipe(
+            pipe_model = MultiProcessPipe(
                model,
                [2, 1],
                style=style,
...
@@ -507,7 +507,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
            failed = False
            with torch.autograd.profiler.profile() as prof:
                try:
-                    if style == Pipe.MultiProcess:
+                    if style == MultiProcessPipe.MultiProcess:
                        pipe_model.back_helper(pipe_output)
                except Exception as e:
                    failed = True
...
tests/nn/pipe_process/skip/test_gpipe.py
View file @
cae9b638
...
@@ -23,7 +23,7 @@ import pytest
...
@@ -23,7 +23,7 @@ import pytest
import
torch
import torch
from torch import nn

-from fairscale.nn.pipe import LazyModule, Pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from fairscale.utils.testing import get_worker_map, torch_spawn
...
@@ -33,12 +33,12 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint, pipeline_style):
    torch.manual_seed(0)
-    if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
        print(f"skipping yarg")
        pytest.skip("Skip tensors NYI for AsyncSchedule")
...
@@ -74,7 +74,7 @@ def x1to3(balance, checkpoint, pipeline_style):
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
-    model = Pipe(
+    model = MultiProcessPipe(
        model,
        balance,
        chunks=3,
...
@@ -106,9 +106,10 @@ def x1to3(balance, checkpoint, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.skip(reason="flaky test")
def none_skip(pipeline_style):
-    if pipeline_style == Pipe.AsyncSchedule:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    @skippable(stash=["none"])
...
@@ -125,7 +126,7 @@ def none_skip(pipeline_style):
        return input

    model = nn.Sequential(Stash(), Pop())
-    model = Pipe(
+    model = MultiProcessPipe(
        model,
        [1, 1],
        style=pipeline_style,
...
@@ -160,7 +161,7 @@ def none_skip(pipeline_style):
@torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def lazy_skippable_error(pipeline_style):
    """Using skippable layers in combination with lazy construction is currently
    not supported, check that it raises an Exception"""
...
@@ -180,6 +181,6 @@ def lazy_skippable_error(pipeline_style):
    ]

    with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
-        Pipe(
+        MultiProcessPipe(
            model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
        )
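Across all of the test files in this commit the change is the same mechanical rename: the multi-process flavour of `Pipe` is now exposed as `MultiProcessPipe`, and its `MultiProcess` / `AsyncSchedule` style attributes move with it. For orientation, here is a minimal sketch of the updated construction pattern, assuming an already initialized distributed context (for example under the `torch_spawn`/`dist_init` helpers used by these tests); the toy model, balance, and tensor shapes are illustrative only and are not taken from the tests themselves.

import torch
from torch import nn

from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map  # test helper; requires an initialized process group

# Toy three-layer model; balance=[2, 1] places two layers on the first rank and one on the second.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# Formerly: Pipe(model, [2, 1], style=Pipe.MultiProcess, ...)
pipe = MultiProcessPipe(
    model,
    balance=[2, 1],
    style=MultiProcessPipe.MultiProcess,  # or MultiProcessPipe.AsyncSchedule
    worker_map=get_worker_map(),
    chunks=2,                             # micro-batches per input batch
    checkpoint="except_last",             # "never" | "always" | "except_last"
)

output = pipe(torch.rand(4, 8))  # each rank runs only its own partition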
tests/nn/pipe_process/skip/test_leak.py  View file @ cae9b638
...
@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn

-from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
+from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from fairscale.utils.testing import get_worker_map, torch_spawn
...
@@ -46,7 +46,7 @@ class Pop(nn.Module):
@torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint, pipeline_style):
...
@@ -60,7 +60,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
    # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
    # +----------+  +------------+  +------------+  +----------+

-    if pipeline_style == Pipe.AsyncSchedule:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    def portal_tensor_life_is(tensor_life, skip_tracker=None):
...
@@ -114,7 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
            return self.F.apply(input)

    model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
-    model = Pipe(
+    model = MultiProcessPipe(
        model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
    )
...
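The skip-tensor tests above (`x1to3`, `none_skip`, `delete_portal_tensor`) all share the same guard: they are parametrized over both pipeline styles but bail out on `AsyncSchedule`, since skip tensors are not yet implemented for that schedule. A condensed sketch of that guard is shown below, with a hypothetical test name (`example_skip_guard`) and a toy model standing in for the real test bodies.

import os

import pytest
import torch
from torch import nn

from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def example_skip_guard(pipeline_style):
    # Skip tensors are NYI for the async schedule, so every skip-based test bails out on that style.
    if pipeline_style == MultiProcessPipe.AsyncSchedule:
        pytest.skip("Skip tensors NYI for AsyncSchedule")

    # Toy two-stage pipeline; the real tests use skippable layers here.
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    pipe = MultiProcessPipe(model, balance=[1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
    pipe(torch.rand(2, 1))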
tests/nn/pipe_process/test_bugs.py  View file @ cae9b638
...
@@ -22,15 +22,15 @@ import torch
from torch import nn
import torch.nn.functional as F

-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def python_autograd_function(pipeline_style):
-    # FIXME deadlock with Pipe.AsyncSchedule?
+    # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
    # A Python autograd function might fail with this error:
    #
    # RuntimeError: Returning Variables sharing storage with other Variables
...
@@ -57,7 +57,9 @@ def python_autograd_function(pipeline_style):
            return Identity.apply(input)

    model = nn.Sequential(M(), M())
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda()
+    model = MultiProcessPipe(
+        model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
+    ).cuda()
    model.eval()
    x = torch.rand(42)
...
@@ -71,7 +73,7 @@ def python_autograd_function(pipeline_style):
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception_no_hang(pipeline_style):
    # In v0.0.2, once a failed partition receives a normal message
    # (non-closing) for the next micro-batch, a hang occured. The reason was
...
@@ -90,7 +92,7 @@ def exception_no_hang(pipeline_style):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Raise())
-    model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
    model.eval()

    if model.group.rank() == 2:
...
@@ -104,7 +106,7 @@ def exception_no_hang(pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def tuple_wait(cuda_sleep, pipeline_style):
    # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
    # Under this behavior, if checkpointing was disabled, there's a possibility
...
@@ -133,7 +135,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
            return a + b + c

    model = nn.Sequential(Layer1(), Layer2())
-    model = Pipe(
+    model = MultiProcessPipe(
        model,
        [1, 1],
        style=pipeline_style,
...
@@ -158,7 +160,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def parallel_randoms(pipeline_style):
    class Dropouts(nn.Module):
        def forward(self, x):
...
@@ -170,7 +172,7 @@ def parallel_randoms(pipeline_style):
    x = torch.rand(10, 10, requires_grad=True).cuda()
    x.retain_grad()
-    model = Pipe(
+    model = MultiProcessPipe(
        model,
        [1, 1],
        style=pipeline_style,
...
tests/nn/pipe_process/test_inplace.py  View file @ cae9b638
...
@@ -21,20 +21,20 @@ import pytest
import torch
from torch import nn

-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_requires_grad(pipeline_style):
    model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
    x = torch.rand(1)

-    if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
        # With AsyncSchedule, model will wait forever for gradients if not eval
        model.eval()
...
@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_not_requires_grad(pipeline_style):
    # In-place operation on a tensor not requiring grad doesn't cause a
    # RuntimeError. Currently, we cannot detect this case.
    model = nn.Sequential(nn.ReLU(inplace=True))
-    model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
    x = torch.rand(1)
    y = model(x)
...
@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_incorrect_grad(pipeline_style):
    class M(nn.Module):
        def forward(self, foo_bar):
...
@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
        return foo * bar

    model = nn.Sequential(M())
-    model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
    foo = torch.tensor([1.0], requires_grad=True)
    bar = torch.tensor([1.0])
...