Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8f77255b
Unverified
Commit
8f77255b
authored
Mar 02, 2021
by
msbaines
Committed by
GitHub
Mar 02, 2021
Browse files
[refactor] multiprocess_pipe: avoid unnecessary use of create_task and other cleanup (#456)
parent
d2924670
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
105 deletions
+99
-105
fairscale/nn/pipe/async_schedule.py
fairscale/nn/pipe/async_schedule.py
+52
-3
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+2
-2
fairscale/nn/pipe/multiprocess_pipeline.py
fairscale/nn/pipe/multiprocess_pipeline.py
+45
-100
No files found.
fairscale/nn/pipe/async_schedule.py
View file @
8f77255b
...
...
@@ -11,15 +11,64 @@ from typing import Dict, Iterable, List, Optional, Tuple
import
torch
from
torch
import
Tensor
,
nn
from
torch.autograd.profiler
import
record_function
from
torch.distributed
import
ProcessGroup
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_ranks
from
.checkpoint
import
Checkpointing
from
.messages
import
Transport
from
.microbatch
import
Batch
from
.multiprocess_pipeline
import
create_task
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
Tensors
from
.skip.tracker
import
SkipTrackerThroughPotals
,
use_skip_tracker
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
TensorOrTensors
,
Tensors
from
.worker
import
Task
def
create_task
(
checkpoint_stop
:
int
,
chunk_id
:
int
,
part_id
:
int
,
batch
:
Batch
,
partition
:
nn
.
Sequential
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
],
)
->
Task
:
# Determine whether checkpointing or not.
if
chunk_id
<
checkpoint_stop
:
def
function
(
input
:
TensorOrTensors
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
chunk_id
],
chunk_id
:
int
=
chunk_id
,
part_id
:
int
=
part_id
,
)
->
TensorOrTensors
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
ret
=
partition
(
input
)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense. It would be much easier to check earlier at this point.
assert
type
(
ret
)
is
not
list
,
"Only Tensor or Tuple of Tensor output is supported"
return
ret
chk
=
Checkpointing
(
function
,
batch
)
task
=
Task
(
None
,
compute
=
chk
.
checkpoint
,
finalize
=
chk
.
recompute
)
del
function
,
chk
# TODO(tom) maybe remove
else
:
def
compute
(
batch
:
Batch
=
batch
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
chunk_id
],
chunk_id
:
int
=
chunk_id
,
part_id
:
int
=
part_id
,
)
->
Batch
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
return
batch
.
call
(
partition
)
task
=
Task
(
None
,
compute
=
compute
,
finalize
=
None
)
del
compute
# TODO(tom) maybe remove
return
task
@
dataclass
(
frozen
=
True
)
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
8f77255b
...
...
@@ -256,7 +256,7 @@ class MultiProcessPipe(Module):
"""Iterates over children of the underlying sequential module."""
return
self
.
partition
.
__iter__
()
def
forward
(
self
,
input
:
TensorOrTensors
,
*
,
event
=
None
)
->
TensorOrTensors
:
# type: ignore
def
forward
(
self
,
input
:
TensorOrTensors
)
->
TensorOrTensors
:
# type: ignore
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
...
...
@@ -284,7 +284,7 @@ class MultiProcessPipe(Module):
# Run pipeline parallelism.
with
self
.
lock
:
self
.
pipeline
.
run
(
self
.
training
,
batches
,
event
)
self
.
pipeline
.
run
(
self
.
training
,
batches
)
if
self
.
final_stage
:
# Merge the micro-batches into one mini-batch.
...
...
fairscale/nn/pipe/multiprocess_pipeline.py
View file @
8f77255b
...
...
@@ -20,7 +20,6 @@
import
os
from
queue
import
Empty
as
QueueEmpty
from
queue
import
Queue
from
threading
import
Event
from
types
import
TracebackType
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
...
...
@@ -39,6 +38,16 @@ from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from
.types
import
ACTIVATIONS_GRADS_QUEUE
,
PORTAL_QUEUE
,
SKIP_TENSOR_QUEUE
,
PipeMessage
,
TensorOrTensors
,
Tensors
from
.worker
import
Task
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if
TYPE_CHECKING
:
InQueue
=
Queue
[
Optional
[
Task
]]
OutQueue
=
Queue
[
Tuple
[
bool
,
Union
[
Tuple
[
Task
,
Batch
],
ExcInfo
,
None
]]]
else
:
InQueue
=
Queue
OutQueue
=
Queue
__all__
:
List
[
str
]
=
[]
ExcInfo
=
Tuple
[
Type
[
BaseException
],
BaseException
,
TracebackType
]
...
...
@@ -49,8 +58,10 @@ class SendOperator(torch.autograd.Function):
@
staticmethod
# type: ignore
def
forward
(
ctx
,
src_rank
,
dst_rank
,
transport
:
Transport
,
input
:
List
[
Tensor
],
index
:
int
)
->
Tensors
:
assert
src_rank
==
torch
.
distributed
.
get_rank
()
def
forward
(
ctx
,
transport
:
Transport
,
input
:
List
[
Tensor
],
index
:
int
)
->
Tensors
:
ranks
=
get_pipeline_parallel_ranks
()
src_rank
=
torch
.
distributed
.
get_rank
()
dst_rank
=
ranks
[
ranks
.
index
(
src_rank
)
+
1
]
transport
.
send_message
(
PipeMessage
(
src_rank
,
dst_rank
,
queue_name
=
ACTIVATIONS_GRADS_QUEUE
,
args
=
index
,
tensors
=
tuple
(
input
)),
...
...
@@ -68,8 +79,7 @@ class RecvOperator(torch.autograd.Function):
@
staticmethod
# type: ignore
def
forward
(
ctx
,
dst_rank
:
int
,
tensor
:
Tensor
,
transport
:
Transport
,
index
:
int
)
->
Tensors
:
assert
dst_rank
==
torch
.
distributed
.
get_rank
()
def
forward
(
ctx
,
tensor
:
Tensor
,
transport
:
Transport
,
index
:
int
)
->
Tensors
:
ctx
.
transport
=
transport
ctx
.
index
=
index
...
...
@@ -86,74 +96,12 @@ class RecvOperator(torch.autograd.Function):
# type: ignore
def
backward
(
ctx
,
*
grad
:
Tensor
,)
->
Tuple
[
Optional
[
Tensor
],
...]:
ranks
=
get_pipeline_parallel_ranks
()
this_rank
=
torch
.
distributed
.
get_rank
()
src_rank
=
torch
.
distributed
.
get_rank
()
dst_rank
=
ranks
[
ranks
.
index
(
src_rank
)
-
1
]
ctx
.
transport
.
send_message
(
PipeMessage
(
this_rank
,
ranks
[
ranks
.
index
(
this_rank
)
-
1
],
queue_name
=
ACTIVATIONS_GRADS_QUEUE
,
args
=
ctx
.
index
,
tensors
=
tuple
(
grad
),
),
PipeMessage
(
src_rank
,
dst_rank
,
queue_name
=
ACTIVATIONS_GRADS_QUEUE
,
args
=
ctx
.
index
,
tensors
=
tuple
(
grad
),),
)
return
(
None
,
None
,
None
,
None
,
None
)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if
TYPE_CHECKING
:
InQueue
=
Queue
[
Optional
[
"Task"
]]
OutQueue
=
Queue
[
Tuple
[
bool
,
Union
[
Tuple
[
"Task"
,
Batch
],
ExcInfo
,
None
]]]
else
:
InQueue
=
Queue
OutQueue
=
Queue
def
create_task
(
checkpoint_stop
:
int
,
i
:
int
,
j
:
int
,
batch
:
Batch
,
partition
:
nn
.
Sequential
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
],
)
->
Task
:
# Determine whether checkpointing or not.
if
i
<
checkpoint_stop
:
def
function
(
input
:
TensorOrTensors
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
i
],
chunk_id
:
int
=
i
,
part_id
:
int
=
j
,
)
->
TensorOrTensors
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
ret
=
partition
(
input
)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense. It would be much easier to check earlier at this point.
assert
type
(
ret
)
is
not
list
,
"Only Tensor or Tuple of Tensor output is supported"
return
ret
chk
=
Checkpointing
(
function
,
batch
)
task
=
Task
(
None
,
compute
=
chk
.
checkpoint
,
finalize
=
chk
.
recompute
)
del
function
,
chk
# TODO(tom) maybe remove
else
:
def
compute
(
batch
:
Batch
=
batch
,
partition
:
nn
.
Sequential
=
partition
,
skip_tracker
:
SkipTrackerThroughPotals
=
skip_trackers
[
i
],
chunk_id
:
int
=
i
,
part_id
:
int
=
j
,
)
->
Batch
:
with
use_skip_tracker
(
skip_tracker
),
record_function
(
"chunk%d-part%d"
%
(
chunk_id
,
part_id
)):
return
batch
.
call
(
partition
)
task
=
Task
(
None
,
compute
=
compute
,
finalize
=
None
)
del
compute
# TODO(tom) maybe remove
return
task
return
(
None
,
None
,
None
,
None
)
class
MultiProcessPipeline
:
...
...
@@ -191,7 +139,7 @@ class MultiProcessPipeline:
return
0
return
self
.
__checkpoint_stop
def
run
(
self
,
training
:
bool
,
batches
:
List
[
Batch
]
,
event
:
Optional
[
Event
]
)
->
None
:
def
run
(
self
,
training
:
bool
,
batches
:
List
[
Batch
])
->
None
:
"""Runs pipeline parallelism.
...
...
@@ -204,24 +152,39 @@ class MultiProcessPipeline:
skip_trackers
=
[
SkipTrackerThroughPotals
(
self
.
skip_layout
,
i
)
for
i
in
range
(
m
)]
schedule
=
[(
i
,
self
.
group
.
rank
()
)
for
i
in
range
(
m
)]
rank
=
self
.
group
.
rank
()
for
i
,
j
in
schedule
:
if
self
.
group
.
rank
()
!=
0
:
for
i
in
range
(
m
)
:
if
rank
!=
0
:
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
else
:
batch
=
batches
[
i
]
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
self
.
partition
,
skip_trackers
)
with
use_skip_tracker
(
skip_trackers
[
i
]),
record_function
(
"chunk%d-part%d"
%
(
i
,
rank
)):
if
i
<
self
.
checkpoint_stop
:
chk
=
Checkpointing
(
self
.
partition
,
batch
)
batch
=
chk
.
checkpoint
()
else
:
batch
=
batch
.
call
(
self
.
partition
)
if
not
self
.
final_stage
:
self
.
send_skip_tensors
(
batch
,
i
,
skip_trackers
)
SendOperator
.
apply
(
self
.
transport
,
[
*
batch
],
i
)
for
portal
in
skip_trackers
[
i
].
portals
.
values
():
portal
.
pipeline
=
self
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
if
i
<
self
.
checkpoint_stop
:
chk
.
recompute
(
batch
)
batches
[
i
]
=
batch
def
get_batch_from_previous_stage
(
self
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
],
batches
:
List
[
Batch
]
)
->
Batch
:
phony
=
torch
.
empty
(
0
,
device
=
self
.
input_device
,
requires_grad
=
True
)
result
=
RecvOperator
.
apply
(
torch
.
distributed
.
get_rank
(),
phony
,
self
.
transport
,
i
)
result
=
RecvOperator
.
apply
(
phony
,
self
.
transport
,
i
)
if
len
(
result
)
==
1
:
batch
=
Batch
(
result
[
0
],
i
)
else
:
...
...
@@ -231,9 +194,10 @@ class MultiProcessPipeline:
return
batch
def
send_skip_tensors
(
self
,
this_rank
:
int
,
ranks
:
List
[
int
],
batch
:
Batch
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
]
)
->
None
:
def
send_skip_tensors
(
self
,
batch
:
Batch
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
None
:
ranks
=
get_pipeline_parallel_ranks
()
this_rank
=
torch
.
distributed
.
get_rank
()
for
next_j
,
ns
,
name
in
self
.
skip_layout
.
copy_policy_by_src
(
self
.
group
.
rank
()):
life
=
skip_trackers
[
i
].
portals
[(
ns
,
name
)].
tensor_life
loaded
=
skip_trackers
[
i
].
load
(
batch
,
ns
,
name
)
...
...
@@ -271,25 +235,6 @@ class MultiProcessPipeline:
except
QueueEmpty
:
break
def
execute_task
(
self
,
task
:
Task
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
Batch
:
batch
=
task
.
compute
()
rank
=
self
.
group
.
rank
()
if
not
self
.
final_stage
:
ranks
=
get_pipeline_parallel_ranks
()
this_rank
=
torch
.
distributed
.
get_rank
()
self
.
send_skip_tensors
(
this_rank
,
ranks
,
batch
,
i
,
skip_trackers
)
SendOperator
.
apply
(
this_rank
,
ranks
[
ranks
.
index
(
this_rank
)
+
1
],
self
.
transport
,
[
*
batch
],
i
)
for
portal
in
skip_trackers
[
i
].
portals
.
values
():
portal
.
pipeline
=
self
task
.
finalize
(
batch
)
return
batch
def
send_portal_grad
(
self
,
ns_name
:
Tuple
[
Namespace
,
str
],
index
:
int
,
grad
:
TensorOrTensors
)
->
None
:
dest
,
src
=
self
.
skip_layout
.
by_ns_name
.
get
(
ns_name
,
(
-
1
,
-
1
))
if
dest
==
src
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment