OpenDAS / fairscale — Commit b5c3638f (unverified)
authored Jan 30, 2021 by msbaines, committed by GitHub on Jan 30, 2021
[refactor] pipe: move async-specific code out of MultiProcessPipeline (#345)
parent a8dd9254

Showing 5 changed files with 68 additions and 100 deletions (+68 −100):
fairscale/nn/pipe/async_pipe.py            +3  −4
fairscale/nn/pipe/async_pipeline.py        +46 −0
fairscale/nn/pipe/multiprocess_pipe.py     +2  −6
fairscale/nn/pipe/multiprocess_pipeline.py +17 −84
fairscale/nn/pipe/types.py                 +0  −6
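At a glance, the refactor replaces dispatch on a PipelineStyle enum inside MultiProcessPipeline with an AsyncPipeline subclass that overrides run(). A minimal sketch of the pattern, with stand-in classes and return values rather than the real fairscale signatures:

def _sketch_refactor_pattern() -> None:
    from enum import Enum, auto

    class PipelineStyle(Enum):  # deleted by this commit
        MultiProcess = auto()
        AsyncSchedule = auto()

    # Before: one class, branching on self.style inside run().
    class PipelineBefore:
        def __init__(self, style: PipelineStyle) -> None:
            self.style = style

        def run(self) -> str:
            if self.style is PipelineStyle.MultiProcess:
                return "synchronous schedule"
            elif self.style is PipelineStyle.AsyncSchedule:
                return "async event loops"
            raise ValueError(self.style)

    # After: the async branch moves into a subclass that overrides run(),
    # so the enum and the branch both disappear.
    class MultiProcessPipeline:
        def run(self) -> str:
            return "synchronous schedule"

    class AsyncPipeline(MultiProcessPipeline):
        def run(self) -> str:
            return "async event loops"

    assert PipelineBefore(PipelineStyle.AsyncSchedule).run() == AsyncPipeline().run()

_sketch_refactor_pattern()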
fairscale/nn/pipe/async_pipe.py
@@ -11,11 +11,11 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
 
+from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
 from .multiprocess_pipe import MultiProcessPipe, check_balance
-from .multiprocess_pipeline import MultiProcessPipeline
 from .skip.skippable import Skippable
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 if TYPE_CHECKING:
     Module = nn.Module[TensorOrTensors]
@@ -43,11 +43,10 @@ class AsyncPipe(MultiProcessPipe):
         # The micro-batch index where the checkpointing stops.
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
 
-        self.pipeline = MultiProcessPipeline(
+        self.pipeline = AsyncPipeline(
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.AsyncSchedule,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
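The checkpoint_stop mapping above is easy to misread: it selects the micro-batch index at which checkpointing stops. A worked example with the values from the hunk; the helper name is ours, and `chunks` is the number of micro-batches:

def checkpoint_stop(chunks: int, checkpoint: str) -> int:
    # Same mapping as in AsyncPipe above.
    return {"always": chunks, "except_last": chunks - 1, "never": 0}[checkpoint]

# With 4 micro-batches:
assert checkpoint_stop(4, "always") == 4       # checkpoint every micro-batch
assert checkpoint_stop(4, "except_last") == 3  # all but the last micro-batch
assert checkpoint_stop(4, "never") == 0        # checkpoint nothing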
fairscale/nn/pipe/async_pipeline.py (new file, 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from threading import Event
from typing import List, Optional

import torch

from .async_schedule import AsyncEventLoop
from .microbatch import Batch
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.tracker import SkipTrackerThroughPotals


class AsyncPipeline(MultiProcessPipeline):
    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        self.training = training

        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

        rank = self.group.rank()
        event_loop = AsyncEventLoop(
            self.partitions,
            self.group,
            self.transport,
            self.training,
            self.checkpoint_stop,
        )
        if rank == 0 and not self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
            event_loop.event_loop_head(batches, skip_trackers, event)
            logging.debug(f"{torch.distributed.get_rank()}: exited event head")
        elif self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
            event_loop.event_loop_tail(batches, skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
        else:
            logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
            event_loop.event_loop(len(batches), skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event loop")

    def back_helper(self, output: List[Batch]) -> None:
        pass
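The branching in AsyncPipeline.run assigns each rank one of three event-loop roles. A small helper restating that logic; the helper itself is illustrative, not part of the commit:

def event_loop_role(rank: int, final_stage: bool) -> str:
    # Mirrors the if/elif/else in AsyncPipeline.run above.
    if rank == 0 and not final_stage:
        return "head"  # first stage: feeds micro-batches in (event_loop_head)
    if final_stage:
        return "tail"  # last stage: drains results (event_loop_tail)
    return "body"      # middle stages: relay activations/grads (event_loop)

# For a 4-stage pipeline with one stage per rank:
roles = [event_loop_role(r, final_stage=(r == 3)) for r in range(4)]
assert roles == ["head", "body", "body", "tail"]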
fairscale/nn/pipe/multiprocess_pipe.py
@@ -37,7 +37,7 @@ from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
 from .skip.layout import SkipLayout, inspect_skip_layout
 from .skip.skippable import Skippable, verify_skippables
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 __all__ = ["MultiProcessPipe", "LazyModule"]
@@ -202,11 +202,8 @@ class MultiProcessPipe(Module):
             list of number of layers in each partition
 
     Keyword Args:
-        style (PipelineStyle):
-            whether to use a single process for all pipeline stages or to assign
-            one stage per process
         group (ProcessGroup):
-            specific to `style=MultiProcess`, the process group that all
+            the process group that all
             pipeline stages are a member of. Defaults to
             `get_pipeline_parallel_group()`
         worker_map (Dict[int, str]):
@@ -374,7 +371,6 @@ class MultiProcessPipe(Module):
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.MultiProcess,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
fairscale/nn/pipe/multiprocess_pipeline.py
@@ -17,7 +17,6 @@
 # limitations under the License.
 """The multiprocess pipeline parallelism of Pipe."""
 import logging
-import os
 from queue import Empty as QueueEmpty
 from queue import Queue
@@ -31,22 +30,14 @@ from torch.autograd.profiler import record_function
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 
-from .async_schedule import AsyncEventLoop, ModuleWrapper
+from .async_schedule import ModuleWrapper
 from .checkpoint import Checkpointing
 from .messages import MakeTransport, Transport
 from .microbatch import Batch
 from .skip import Namespace
 from .skip.layout import SkipLayout
 from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
-from .types import (
-    ACTIVATIONS_GRADS_QUEUE,
-    PORTAL_QUEUE,
-    SKIP_TENSOR_QUEUE,
-    PipelineStyle,
-    PipeMessage,
-    TensorOrTensors,
-    Tensors,
-)
+from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
 from .worker import Task
 
 __all__: List[str] = []
@@ -174,8 +165,8 @@ class MultiProcessPipeline:
         partitions: List[ModuleWrapper],
         skip_layout: SkipLayout,
         checkpoint_stop: int,
-        style: PipelineStyle,
         group: torch.distributed.ProcessGroup,
         *,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
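Note the bare `*` in the signature above: worker_map, input_device, and final_stage remain keyword-only, while the positional prefix shrinks by one with `style` gone. A stand-in class to illustrate the keyword-only boundary, not the real constructor:

class StageConfig:
    def __init__(self, partitions, skip_layout, checkpoint_stop, group, *,
                 worker_map=None, input_device=None, final_stage=False):
        self.final_stage = final_stage

# Parameters after `*` must be passed by keyword:
cfg = StageConfig(["partition0"], None, 0, None, final_stage=True)
assert cfg.final_stage
# Passing final_stage positionally, e.g. StageConfig([], None, 0, None, None, None, True),
# would raise TypeError.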
@@ -183,7 +174,6 @@ class MultiProcessPipeline:
         self.partitions = partitions
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
-        self.style = style
         self.group = group
         self.training: bool
         self.transport = MakeTransport(
@@ -192,7 +182,6 @@ class MultiProcessPipeline:
             input_device=input_device,
         )
         self.input_device = input_device
-        self.callcount = 0
         self.final_stage = final_stage
 
     @property
@@ -214,30 +203,22 @@ class MultiProcessPipeline:
         m = len(batches)
-        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
+        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]
 
-        if self.style is PipelineStyle.MultiProcess:
-            schedule = [(i, self.group.rank()) for i in range(m)]
-            self.compute(batches, schedule, skip_trackers)
-        elif self.style is PipelineStyle.AsyncSchedule:
-            rank = self.group.rank()
-            event_loop = AsyncEventLoop(
-                self.partitions,
-                self.group,
-                self.transport,
-                self.training,
-                self.checkpoint_stop,
-            )
-            if rank == 0 and not self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event head")
-                event_loop.event_loop_head(batches, skip_trackers, event)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event head")
-            elif self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
-                event_loop.event_loop_tail(batches, skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
-            else:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
-                event_loop.event_loop(len(batches), skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
+        schedule = [(i, self.group.rank()) for i in range(m)]
 
-        self.callcount += 1
+        for i, j in schedule:
+            assert len(self.partitions) == 1
+            partition = self.partitions[0]
+
+            if self.group.rank() != 0:
+                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            else:
+                batch = batches[i]
+
+            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+
+            batches[i] = self.execute_task(task, i, skip_trackers)
 
     def get_batch_from_previous_stage(
         self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
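With the async branch gone, what remains in run() is the body of the former compute() method, which a later hunk in this file deletes wholesale. The synchronous schedule it iterates is simply each micro-batch index paired with this stage's fixed rank; a restated sketch, with a helper name of our own:

from typing import List, Tuple

def make_schedule(num_microbatches: int, rank: int) -> List[Tuple[int, int]]:
    # Same expression as `schedule = [(i, self.group.rank()) for i in range(m)]`.
    return [(i, rank) for i in range(num_microbatches)]

# A stage at rank 2 with 4 micro-batches processes them strictly in order:
assert make_schedule(4, 2) == [(0, 2), (1, 2), (2, 2), (3, 2)]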
@@ -299,7 +280,7 @@ class MultiProcessPipeline:
         rank = self.group.rank()
 
-        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
+        if not self.final_stage:
             ranks = get_pipeline_parallel_ranks()
             this_rank = torch.distributed.get_rank()
@@ -313,51 +294,6 @@ class MultiProcessPipeline:
         return batch
 
-    def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
-        """Runs tasks with synchronization to copy streams."""
-        assert self.style is PipelineStyle.MultiProcess
-
-        # With checkpointing, the autograd graph looks like this diagram:
-        # ┌─────┸──────┐
-        # │    Copy    │
-        # └─────┰──────┘   (fence)
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        #       ┃          (compute)
-        # ┌─────┸──────┐
-        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │ Checkpoint │ [2] Compute a partition within checkpointing.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
-        # └─────┰──────┘
-        #       ┠ ─ ─ ─ ┐
-        #       ┃ ┌─────┴─────┐
-        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
-        #       ┃ └─────┬─────┘
-        #       ┠ ─ ─ ─ ┘
-        #       ┃
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        # ┌─────┸──────┐   (fence)
-        # │    Copy    │
-        # └─────┰──────┘
-        for i, j in schedule:
-            assert len(self.partitions) == 1
-            partition = self.partitions[0]
-
-            if self.group.rank() != 0:
-                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
-            else:
-                batch = batches[i]
-
-            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
-
-            batches[i] = self.execute_task(task, i, skip_trackers)
-
     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
         if dest == src:
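The deleted compute() carried the diagram explaining checkpointed execution: Copy, Wait, Checkpoint, then a Recompute node scheduled for backpropagation. That mechanism lives on in create_task/execute_task. As a generic illustration of the recompute-on-backward idea, here is plain torch.utils.checkpoint rather than fairscale's Checkpointing class:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
x = torch.randn(2, 8, requires_grad=True)

# Forward: activations inside `layer` are dropped instead of stored.
y = checkpoint(layer, x)

# Backward: the forward pass is recomputed to rebuild those activations
# before gradients flow through them.
y.sum().backward()
assert x.grad is not None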
@@ -385,9 +321,6 @@ class MultiProcessPipeline:
         return result
 
     def back_helper(self, output: List[Batch]) -> None:
-        if self.style == PipelineStyle.AsyncSchedule:
-            return
-
         tensors: Tensors
 
         rank = torch.distributed.get_rank()
fairscale/nn/pipe/types.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
@@ -34,11 +33,6 @@ class LazyModule:
         return self.function()
 
 
-class PipelineStyle(Enum):
-    MultiProcess = auto()
-    AsyncSchedule = auto()
-
-
 @dataclass(init=False)
 class PipeMessage:
     src: int