Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8f77255b
Unverified
Commit
8f77255b
authored
Mar 02, 2021
by
msbaines
Committed by
GitHub
Mar 02, 2021
Browse files
[refactor] multiprocess_pipe: avoid unnecessary use of create_task and other cleanup (#456)
parent
d2924670
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
105 deletions
+99
-105
fairscale/nn/pipe/async_schedule.py
fairscale/nn/pipe/async_schedule.py
+52
-3
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+2
-2
fairscale/nn/pipe/multiprocess_pipeline.py
fairscale/nn/pipe/multiprocess_pipeline.py
+45
-100
No files found.
fairscale/nn/pipe/async_schedule.py
View file @
8f77255b
...
...
@@ -11,15 +11,64 @@ from typing import Dict, Iterable, List, Optional, Tuple
import
torch
from
torch
import
Tensor
,
nn
from
torch.autograd.profiler
import
record_function
from
torch.distributed
import
ProcessGroup
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_ranks
from
.checkpoint
import
Checkpointing
from
.messages
import
Transport
from
.microbatch
import
Batch
from
.multiprocess_pipeline
import
create_task
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
Tensors
from
.skip.tracker
import
SkipTrackerThroughPotals
,
use_skip_tracker
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
TensorOrTensors
,
Tensors
from
.worker
import
Task
def
create_task
(
checkpoint_stop
:
int
,
chunk_id
:
int
,
part_id
:
int
,
batch
:
Batch
,
partition
:
nn
.
Sequential
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
],
)
->
Task
:
# Determine whether checkpointing or not.
if
chunk_id
<
checkpoint_stop
:
def
function
(
input
:
TensorOrTensors
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
chunk_id
],
chunk_id
:
int
=
chunk_id
,
part_id
:
int
=
part_id
,
)
->
TensorOrTensors
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
ret
=
partition
(
input
)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense. It would be much easier to check earlier at this point.
assert
type
(
ret
)
is
not
list
,
"Only Tensor or Tuple of Tensor output is supported"
return
ret
chk
=
Checkpointing
(
function
,
batch
)
task
=
Task
(
None
,
compute
=
chk
.
checkpoint
,
finalize
=
chk
.
recompute
)
del
function
,
chk
# TODO(tom) maybe remove
else
:
def
compute
(
batch
:
Batch
=
batch
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
chunk_id
],
chunk_id
:
int
=
chunk_id
,
part_id
:
int
=
part_id
,
)
->
Batch
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
return
batch
.
call
(
partition
)
task
=
Task
(
None
,
compute
=
compute
,
finalize
=
None
)
del
compute
# TODO(tom) maybe remove
return
task
@
dataclass
(
frozen
=
True
)
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
8f77255b
...
...
@@ -256,7 +256,7 @@ class MultiProcessPipe(Module):
"""Iterates over children of the underlying sequential module."""
return
self
.
partition
.
__iter__
()
def
forward
(
self
,
input
:
TensorOrTensors
,
*
,
event
=
None
)
->
TensorOrTensors
:
# type: ignore
def
forward
(
self
,
input
:
TensorOrTensors
)
->
TensorOrTensors
:
# type: ignore
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
...
...
@@ -284,7 +284,7 @@ class MultiProcessPipe(Module):
# Run pipeline parallelism.
with
self
.
lock
:
self
.
pipeline
.
run
(
self
.
training
,
batches
,
event
)
self
.
pipeline
.
run
(
self
.
training
,
batches
)
if
self
.
final_stage
:
# Merge the micro-batches into one mini-batch.
...
...
fairscale/nn/pipe/multiprocess_pipeline.py
View file @
8f77255b
...
...
@@ -20,7 +20,6 @@
import
os
from
queue
import
Empty
as
QueueEmpty
from
queue
import
Queue
from
threading
import
Event
from
types
import
TracebackType
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
...
...
@@ -39,6 +38,16 @@ from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from
.types
import
ACTIVATIONS_GRADS_QUEUE
,
PORTAL_QUEUE
,
SKIP_TENSOR_QUEUE
,
PipeMessage
,
TensorOrTensors
,
Tensors
from
.worker
import
Task
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if
TYPE_CHECKING
:
InQueue
=
Queue
[
Optional
[
Task
]]
OutQueue
=
Queue
[
Tuple
[
bool
,
Union
[
Tuple
[
Task
,
Batch
],
ExcInfo
,
None
]]]
else
:
InQueue
=
Queue
OutQueue
=
Queue
__all__
:
List
[
str
]
=
[]
ExcInfo
=
Tuple
[
Type
[
BaseException
],
BaseException
,
TracebackType
]
...
...
@@ -49,8 +58,10 @@ class SendOperator(torch.autograd.Function):
@
staticmethod
# type: ignore
def
forward
(
ctx
,
src_rank
,
dst_rank
,
transport
:
Transport
,
input
:
List
[
Tensor
],
index
:
int
)
->
Tensors
:
assert
src_rank
==
torch
.
distributed
.
get_rank
()
def
forward
(
ctx
,
transport
:
Transport
,
input
:
List
[
Tensor
],
index
:
int
)
->
Tensors
:
ranks
=
get_pipeline_parallel_ranks
()
src_rank
=
torch
.
distributed
.
get_rank
()
dst_rank
=
ranks
[
ranks
.
index
(
src_rank
)
+
1
]
transport
.
send_message
(
PipeMessage
(
src_rank
,
dst_rank
,
queue_name
=
ACTIVATIONS_GRADS_QUEUE
,
args
=
index
,
tensors
=
tuple
(
input
)),
...
...
@@ -68,8 +79,7 @@ class RecvOperator(torch.autograd.Function):
@
staticmethod
# type: ignore
def
forward
(
ctx
,
dst_rank
:
int
,
tensor
:
Tensor
,
transport
:
Transport
,
index
:
int
)
->
Tensors
:
assert
dst_rank
==
torch
.
distributed
.
get_rank
()
def
forward
(
ctx
,
tensor
:
Tensor
,
transport
:
Transport
,
index
:
int
)
->
Tensors
:
ctx
.
transport
=
transport
ctx
.
index
=
index
...
...
@@ -86,74 +96,12 @@ class RecvOperator(torch.autograd.Function):
# type: ignore
def
backward
(
ctx
,
*
grad
:
Tensor
,)
->
Tuple
[
Optional
[
Tensor
],
...]:
ranks
=
get_pipeline_parallel_ranks
()
this_rank
=
torch
.
distributed
.
get_rank
()
src_rank
=
torch
.
distributed
.
get_rank
()
dst_rank
=
ranks
[
ranks
.
index
(
src_rank
)
-
1
]
ctx
.
transport
.
send_message
(
PipeMessage
(
this_rank
,
ranks
[
ranks
.
index
(
this_rank
)
-
1
],
queue_name
=
ACTIVATIONS_GRADS_QUEUE
,
args
=
ctx
.
index
,
tensors
=
tuple
(
grad
),
),
PipeMessage
(
src_rank
,
dst_rank
,
queue_name
=
ACTIVATIONS_GRADS_QUEUE
,
args
=
ctx
.
index
,
tensors
=
tuple
(
grad
),),
)
return
(
None
,
None
,
None
,
None
,
None
)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if
TYPE_CHECKING
:
InQueue
=
Queue
[
Optional
[
"Task"
]]
OutQueue
=
Queue
[
Tuple
[
bool
,
Union
[
Tuple
[
"Task"
,
Batch
],
ExcInfo
,
None
]]]
else
:
InQueue
=
Queue
OutQueue
=
Queue
def
create_task
(
checkpoint_stop
:
int
,
i
:
int
,
j
:
int
,
batch
:
Batch
,
partition
:
nn
.
Sequential
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
],
)
->
Task
:
# Determine whether checkpointing or not.
if
i
<
checkpoint_stop
:
def
function
(
input
:
TensorOrTensors
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
i
],
chunk_id
:
int
=
i
,
part_id
:
int
=
j
,
)
->
TensorOrTensors
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
ret
=
partition
(
input
)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense. It would be much easier to check earlier at this point.
assert
type
(
ret
)
is
not
list
,
"Only Tensor or Tuple of Tensor output is supported"
return
ret
chk
=
Checkpointing
(
function
,
batch
)
task
=
Task
(
None
,
compute
=
chk
.
checkpoint
,
finalize
=
chk
.
recompute
)
del
function
,
chk
# TODO(tom) maybe remove
else
:
def
compute
(
batch
:
Batch
=
batch
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
i
],
chunk_id
:
int
=
i
,
part_id
:
int
=
j
,
)
->
Batch
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
return
batch
.
call
(
partition
)
task
=
Task
(
None
,
compute
=
compute
,
finalize
=
None
)
del
compute
# TODO(tom) maybe remove
return
task
return
(
None
,
None
,
None
,
None
)
class
MultiProcessPipeline
:
...
...
@@ -191,7 +139,7 @@ class MultiProcessPipeline:
return
0
return
self
.
__checkpoint_stop
def
run
(
self
,
training
:
bool
,
batches
:
List
[
Batch
]
,
event
:
Optional
[
Event
]
)
->
None
:
def
run
(
self
,
training
:
bool
,
batches
:
List
[
Batch
])
->
None
:
"""Runs pipeline parallelism.
...
...
@@ -204,24 +152,39 @@ class MultiProcessPipeline:
skip_trackers
=
[
SkipTrackerThroughPotals
(
self
.
skip_layout
,
i
)
for
i
in
range
(
m
)]
schedule
=
[(
i
,
self
.
group
.
rank
()
)
for
i
in
range
(
m
)]
rank
=
self
.
group
.
rank
()
for
i
,
j
in
schedule
:
if
self
.
group
.
rank
()
!=
0
:
for
i
in
range
(
m
)
:
if
rank
!=
0
:
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
else
:
batch
=
batches
[
i
]
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
self
.
partition
,
skip_trackers
)
with
use_skip_tracker
(
skip_trackers
[
i
]),
record_function
(
"chunk%d-part%d"
%
(
i
,
rank
)):
if
i
<
self
.
checkpoint_stop
:
chk
=
Checkpointing
(
self
.
partition
,
batch
)
batch
=
chk
.
checkpoint
()
else
:
batch
=
batch
.
call
(
self
.
partition
)
if
not
self
.
final_stage
:
self
.
send_skip_tensors
(
batch
,
i
,
skip_trackers
)
SendOperator
.
apply
(
self
.
transport
,
[
*
batch
],
i
)
for
portal
in
skip_trackers
[
i
].
portals
.
values
():
portal
.
pipeline
=
self
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
if
i
<
self
.
checkpoint_stop
:
chk
.
recompute
(
batch
)
batches
[
i
]
=
batch
def
get_batch_from_previous_stage
(
self
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
],
batches
:
List
[
Batch
]
)
->
Batch
:
phony
=
torch
.
empty
(
0
,
device
=
self
.
input_device
,
requires_grad
=
True
)
result
=
RecvOperator
.
apply
(
torch
.
distributed
.
get_rank
(),
phony
,
self
.
transport
,
i
)
result
=
RecvOperator
.
apply
(
phony
,
self
.
transport
,
i
)
if
len
(
result
)
==
1
:
batch
=
Batch
(
result
[
0
],
i
)
else
:
...
...
@@ -231,9 +194,10 @@ class MultiProcessPipeline:
return
batch
def
send_skip_tensors
(
self
,
this_rank
:
int
,
ranks
:
List
[
int
],
batch
:
Batch
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
]
)
->
None
:
def
send_skip_tensors
(
self
,
batch
:
Batch
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
None
:
ranks
=
get_pipeline_parallel_ranks
()
this_rank
=
torch
.
distributed
.
get_rank
()
for
next_j
,
ns
,
name
in
self
.
skip_layout
.
copy_policy_by_src
(
self
.
group
.
rank
()):
life
=
skip_trackers
[
i
].
portals
[(
ns
,
name
)].
tensor_life
loaded
=
skip_trackers
[
i
].
load
(
batch
,
ns
,
name
)
...
...
@@ -271,25 +235,6 @@ class MultiProcessPipeline:
except
QueueEmpty
:
break
def
execute_task
(
self
,
task
:
Task
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
Batch
:
batch
=
task
.
compute
()
rank
=
self
.
group
.
rank
()
if
not
self
.
final_stage
:
ranks
=
get_pipeline_parallel_ranks
()
this_rank
=
torch
.
distributed
.
get_rank
()
self
.
send_skip_tensors
(
this_rank
,
ranks
,
batch
,
i
,
skip_trackers
)
SendOperator
.
apply
(
this_rank
,
ranks
[
ranks
.
index
(
this_rank
)
+
1
],
self
.
transport
,
[
*
batch
],
i
)
for
portal
in
skip_trackers
[
i
].
portals
.
values
():
portal
.
pipeline
=
self
task
.
finalize
(
batch
)
return
batch
def
send_portal_grad
(
self
,
ns_name
:
Tuple
[
Namespace
,
str
],
index
:
int
,
grad
:
TensorOrTensors
)
->
None
:
dest
,
src
=
self
.
skip_layout
.
by_ns_name
.
get
(
ns_name
,
(
-
1
,
-
1
))
if
dest
==
src
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment