OpenDAS / fairscale · Commit b5c3638f (unverified)

[refactor] pipe: move async-specific code out of MultiProcessPipeline (#345)

Authored by msbaines on Jan 30, 2021; committed via GitHub on Jan 30, 2021.
Parent commit: a8dd9254
Changes: 5 changed files, with 68 additions and 100 deletions.
fairscale/nn/pipe/async_pipe.py              +3   -4
fairscale/nn/pipe/async_pipeline.py          +46  -0
fairscale/nn/pipe/multiprocess_pipe.py       +2   -6
fairscale/nn/pipe/multiprocess_pipeline.py   +17  -84
fairscale/nn/pipe/types.py                   +0   -6
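The refactor replaces a runtime style flag with subclass dispatch: the async schedule moves into a new AsyncPipeline class and the PipelineStyle enum goes away. The following is a standalone sketch of that pattern with toy class names, not the fairscale API; only the shape of the change is meant to match the diffs below.

# Standalone illustration (toy classes, not fairscale API): behaviour selected
# by a flag versus behaviour selected by the class itself.
from enum import Enum, auto


class Style(Enum):  # stand-in for the removed PipelineStyle
    MultiProcess = auto()
    AsyncSchedule = auto()


class ToyFlagPipeline:  # "before": one class branches on a style flag
    def __init__(self, style: Style) -> None:
        self.style = style

    def run(self) -> str:
        return "sync schedule" if self.style is Style.MultiProcess else "async schedule"


class ToySyncPipeline:  # "after": each schedule is its own class
    def run(self) -> str:
        return "sync schedule"


class ToyAsyncPipeline(ToySyncPipeline):
    def run(self) -> str:
        return "async schedule"


assert ToyFlagPipeline(Style.AsyncSchedule).run() == ToyAsyncPipeline().run()

Choosing the class up front removes the per-call branching that MultiProcessPipeline previously did on self.style.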
fairscale/nn/pipe/async_pipe.py
@@ -11,11 +11,11 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
 
+from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
 from .multiprocess_pipe import MultiProcessPipe, check_balance
-from .multiprocess_pipeline import MultiProcessPipeline
 from .skip.skippable import Skippable
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 if TYPE_CHECKING:
     Module = nn.Module[TensorOrTensors]
@@ -43,11 +43,10 @@ class AsyncPipe(MultiProcessPipe):
         # The micro-batch index where the checkpointing stops.
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
 
-        self.pipeline = MultiProcessPipeline(
+        self.pipeline = AsyncPipeline(
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.AsyncSchedule,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
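For reference, the checkpoint_stop lookup above maps the checkpoint mode string to the index of the first micro-batch that is no longer checkpointed. A minimal illustration, assuming a hypothetical chunks=4:

# Illustration of the checkpoint_stop mapping for a hypothetical chunks=4.
chunks = 4
for mode in ("always", "except_last", "never"):
    checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
    print(mode, checkpoint_stop)
# always 4       -> checkpoint all micro-batches (0..3)
# except_last 3  -> checkpoint micro-batches 0..2
# never 0        -> checkpoint nothing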
fairscale/nn/pipe/async_pipeline.py (new file, 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from threading import Event
from typing import List, Optional

import torch

from .async_schedule import AsyncEventLoop
from .microbatch import Batch
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.tracker import SkipTrackerThroughPotals


class AsyncPipeline(MultiProcessPipeline):
    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        self.training = training

        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

        rank = self.group.rank()

        event_loop = AsyncEventLoop(
            self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
        )
        if rank == 0 and not self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
            event_loop.event_loop_head(batches, skip_trackers, event)
            logging.debug(f"{torch.distributed.get_rank()}: exited event head")
        elif self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
            event_loop.event_loop_tail(batches, skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
        else:
            logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
            event_loop.event_loop(len(batches), skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event loop")

    def back_helper(self, output: List[Batch]) -> None:
        pass
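The branch in run() assigns one event loop per pipeline stage: rank 0 drives the head loop, the final stage drives the tail loop, and every stage in between drives the plain loop. A minimal sketch of that dispatch, assuming a hypothetical 4-stage group whose last rank is the final stage:

# Sketch only: which loop each stage would drive in a hypothetical 4-stage
# pipeline group, mirroring the branch in AsyncPipeline.run above. Treating
# the last rank as the final stage is an assumption of this illustration.
def loop_for(rank: int, world_size: int) -> str:
    final_stage = rank == world_size - 1
    if rank == 0 and not final_stage:
        return "event_loop_head"
    elif final_stage:
        return "event_loop_tail"
    return "event_loop"


print([loop_for(r, 4) for r in range(4)])
# ['event_loop_head', 'event_loop', 'event_loop', 'event_loop_tail']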
fairscale/nn/pipe/multiprocess_pipe.py
@@ -37,7 +37,7 @@ from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
 from .skip.layout import SkipLayout, inspect_skip_layout
 from .skip.skippable import Skippable, verify_skippables
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 __all__ = ["MultiProcessPipe", "LazyModule"]
 
@@ -202,11 +202,8 @@ class MultiProcessPipe(Module):
         list of number of layers in each partition
 
     Keyword Args:
-        style (PipelineStyle):
-            whether to use a single process for all pipeline stages or to assign
-            one stage per process
         group (ProcessGroup):
-            specific to `style=MultiProcess`, the process group that all
+            the process group that all
             pipeline stages are a member of. Defaults to
             `get_pipeline_parallel_group()`
         worker_map (Dict[int, str]):
 
@@ -374,7 +371,6 @@ class MultiProcessPipe(Module):
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.MultiProcess,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
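multiprocess_pipe.py continues to re-export LazyModule from .types; only the now-unused PipelineStyle import is dropped. A hedged usage sketch of LazyModule follows; its constructor and call signature are assumed from the `return self.function()` shown in the types.py hunk below, not spelled out in this diff.

# Hedged sketch (constructor and call signature assumed, not taken verbatim
# from this diff): LazyModule wraps a zero-argument factory so a partition's
# module is only built on the rank that owns it.
import torch.nn as nn
from fairscale.nn.pipe.types import LazyModule

lazy_layer = LazyModule(lambda: nn.Linear(8, 8))  # nothing constructed yet
layer = lazy_layer()  # calling it runs the factory and returns the nn.Linear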
fairscale/nn/pipe/multiprocess_pipeline.py
@@ -17,7 +17,6 @@
 # limitations under the License.
 
 """The multiprocess pipeline parallelism of Pipe."""
-import logging
 import os
 from queue import Empty as QueueEmpty
 from queue import Queue
 
@@ -31,22 +30,14 @@ from torch.autograd.profiler import record_function
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 
-from .async_schedule import AsyncEventLoop, ModuleWrapper
+from .async_schedule import ModuleWrapper
 from .checkpoint import Checkpointing
 from .messages import MakeTransport, Transport
 from .microbatch import Batch
 from .skip import Namespace
 from .skip.layout import SkipLayout
 from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
-from .types import (
-    ACTIVATIONS_GRADS_QUEUE,
-    PORTAL_QUEUE,
-    SKIP_TENSOR_QUEUE,
-    PipelineStyle,
-    PipeMessage,
-    TensorOrTensors,
-    Tensors,
-)
+from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
 from .worker import Task
 
 __all__: List[str] = []
 
@@ -174,8 +165,8 @@ class MultiProcessPipeline:
         partitions: List[ModuleWrapper],
         skip_layout: SkipLayout,
         checkpoint_stop: int,
-        style: PipelineStyle,
         group: torch.distributed.ProcessGroup,
+        *,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
 
@@ -183,7 +174,6 @@ class MultiProcessPipeline:
         self.partitions = partitions
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
-        self.style = style
         self.group = group
         self.training: bool
         self.transport = MakeTransport(
 
@@ -192,7 +182,6 @@ class MultiProcessPipeline:
             input_device=input_device,
         )
         self.input_device = input_device
-        self.callcount = 0
         self.final_stage = final_stage
 
     @property
 
@@ -214,30 +203,22 @@ class MultiProcessPipeline:
         m = len(batches)
-        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
+        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]
 
-        if self.style is PipelineStyle.MultiProcess:
-            schedule = [(i, self.group.rank()) for i in range(m)]
-            self.compute(batches, schedule, skip_trackers)
-        elif self.style is PipelineStyle.AsyncSchedule:
-            rank = self.group.rank()
-            event_loop = AsyncEventLoop(
-                self.partitions,
-                self.group,
-                self.transport,
-                self.training,
-                self.checkpoint_stop,
-            )
-            if rank == 0 and not self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event head")
-                event_loop.event_loop_head(batches, skip_trackers, event)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event head")
-            elif self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
-                event_loop.event_loop_tail(batches, skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
-            else:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
-                event_loop.event_loop(len(batches), skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
-
-        self.callcount += 1
+        schedule = [(i, self.group.rank()) for i in range(m)]
+
+        for i, j in schedule:
+            assert len(self.partitions) == 1
+            partition = self.partitions[0]
+
+            if self.group.rank() != 0:
+                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            else:
+                batch = batches[i]
+
+            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+
+            batches[i] = self.execute_task(task, i, skip_trackers)
 
     def get_batch_from_previous_stage(
         self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
 
@@ -299,7 +280,7 @@ class MultiProcessPipeline:
         rank = self.group.rank()
 
-        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
+        if not self.final_stage:
             ranks = get_pipeline_parallel_ranks()
             this_rank = torch.distributed.get_rank()
 
@@ -313,51 +294,6 @@ class MultiProcessPipeline:
         return batch
 
-    def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
-        """Runs tasks with synchronization to copy streams."""
-        assert self.style is PipelineStyle.MultiProcess
-
-        # With checkpointing, the autograd graph looks like this diagram:
-        # ┌─────┸──────┐
-        # │    Copy    │
-        # └─────┰──────┘   (fence)
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        #       ┃          (compute)
-        # ┌─────┸──────┐
-        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │ Checkpoint │ [2] Compute a partition within checkpointing.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
-        # └─────┰──────┘
-        #       ┠ ─ ─ ─ ┐
-        #       ┃ ┌─────┴─────┐
-        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
-        #       ┃ └─────┬─────┘
-        #       ┠ ─ ─ ─ ┘
-        #       ┃
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        # ┌─────┸──────┐   (fence)
-        # │    Copy    │
-        # └─────┰──────┘
-        for i, j in schedule:
-            assert len(self.partitions) == 1
-            partition = self.partitions[0]
-
-            if self.group.rank() != 0:
-                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
-            else:
-                batch = batches[i]
-
-            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
-
-            batches[i] = self.execute_task(task, i, skip_trackers)
-
     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
         if dest == src:
 
@@ -385,9 +321,6 @@ class MultiProcessPipeline:
         return result
 
     def back_helper(self, output: List[Batch]) -> None:
-        if self.style == PipelineStyle.AsyncSchedule:
-            return
-
         tensors: Tensors
 
         rank = torch.distributed.get_rank()
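With the async branch gone, the simplified run() builds a purely local schedule: one (micro_batch_index, rank) pair per micro-batch, executed in order on the current stage. A minimal illustration with made-up values:

# Illustration: the schedule built by the simplified run() for m micro-batches
# on a stage whose group rank is 2 (values chosen only for the example).
m, rank = 3, 2
schedule = [(i, rank) for i in range(m)]
print(schedule)  # [(0, 2), (1, 2), (2, 2)]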
fairscale/nn/pipe/types.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 
@@ -34,11 +33,6 @@ class LazyModule:
         return self.function()
 
 
-class PipelineStyle(Enum):
-    MultiProcess = auto()
-    AsyncSchedule = auto()
-
-
 @dataclass(init=False)
 class PipeMessage:
     src: int