OpenDAS / fairscale · Commits

Commit 8f77255b (unverified)
Authored Mar 02, 2021 by msbaines; committed via GitHub on Mar 02, 2021

[refactor] multiprocess_pipe: avoid unnecessary use of create_task and other cleanup (#456)
Parent: d2924670

Changes: 3 changed files, +99 additions and -105 deletions

  fairscale/nn/pipe/async_schedule.py         +52    -3
  fairscale/nn/pipe/multiprocess_pipe.py       +2    -2
  fairscale/nn/pipe/multiprocess_pipeline.py  +45  -100
fairscale/nn/pipe/async_schedule.py
@@ -11,15 +11,64 @@ from typing import Dict, Iterable, List, Optional, Tuple
 import torch
 from torch import Tensor, nn
+from torch.autograd.profiler import record_function
 from torch.distributed import ProcessGroup

 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

+from .checkpoint import Checkpointing
 from .messages import Transport
 from .microbatch import Batch
-from .multiprocess_pipeline import create_task
-from .skip.tracker import SkipTrackerThroughPotals
-from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors
+from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
+from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors, Tensors
+from .worker import Task
+
+
+def create_task(
+    checkpoint_stop: int,
+    chunk_id: int,
+    part_id: int,
+    batch: Batch,
+    partition: nn.Sequential,
+    skip_trackers: List[SkipTrackerThroughPotals],
+) -> Task:
+    # Determine whether checkpointing or not.
+    if chunk_id < checkpoint_stop:
+
+        def function(
+            input: TensorOrTensors,
+            partition: nn.Sequential = partition,
+            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
+            chunk_id: int = chunk_id,
+            part_id: int = part_id,
+        ) -> TensorOrTensors:
+            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                ret = partition(input)
+                # We do a check here because the backtrace from the checkpoint backward code path
+                # is very hard to make sense. It would be much easier to check earlier at this point.
+                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
+                return ret
+
+        chk = Checkpointing(function, batch)
+        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
+        del function, chk  # TODO(tom) maybe remove
+    else:
+
+        def compute(
+            batch: Batch = batch,
+            partition: nn.Sequential = partition,
+            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
+            chunk_id: int = chunk_id,
+            part_id: int = part_id,
+        ) -> Batch:
+            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                return batch.call(partition)
+
+        task = Task(None, compute=compute, finalize=None)
+        del compute  # TODO(tom) maybe remove
+
+    return task
+
+
 @dataclass(frozen=True)
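A note on the `partition=partition, chunk_id=chunk_id, ...` defaults in the moved `create_task`: Python closures capture variables by name, not by value, so a task built inside a scheduling loop would otherwise see the bindings of the final iteration when its compute finally runs. Binding the current values as default arguments freezes them per task. A minimal illustration of the pitfall and the fix (toy code, not from this diff):

# Pitfall: all three closures share the name `i`, which ends at 2.
late = [lambda: i for i in range(3)]
print([f() for f in late])   # [2, 2, 2]

# The fix used by create_task: bind the current value as a default argument.
early = [lambda i=i: i for i in range(3)]
print([f() for f in early])  # [0, 1, 2]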
fairscale/nn/pipe/multiprocess_pipe.py
@@ -256,7 +256,7 @@ class MultiProcessPipe(Module):
"""Iterates over children of the underlying sequential module."""
"""Iterates over children of the underlying sequential module."""
return
self
.
partition
.
__iter__
()
return
self
.
partition
.
__iter__
()
def
forward
(
self
,
input
:
TensorOrTensors
,
*
,
event
=
None
)
->
TensorOrTensors
:
# type: ignore
def
forward
(
self
,
input
:
TensorOrTensors
)
->
TensorOrTensors
:
# type: ignore
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
there's type restriction. Input and output have to be a
@@ -284,7 +284,7 @@ class MultiProcessPipe(Module):
         # Run pipeline parallelism.
         with self.lock:
-            self.pipeline.run(self.training, batches, event)
+            self.pipeline.run(self.training, batches)
         if self.final_stage:
             # Merge the micro-batches into one mini-batch.
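With the `event` plumbing removed from both `forward` and `MultiProcessPipeline.run`, a pipe is now invoked like any other `nn.Module`. A hedged call-site sketch (the `pipe` object and the old keyword usage are assumed, not shown in this diff):

# before this commit (hypothetical call site): output = pipe(input, event=done_event)
output = pipe(input)  # input and output must be a Tensor or a tuple of Tensors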
fairscale/nn/pipe/multiprocess_pipeline.py
@@ -20,7 +20,6 @@
 import os
 from queue import Empty as QueueEmpty
 from queue import Queue
-from threading import Event
 from types import TracebackType
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
@@ -39,6 +38,16 @@ from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
 from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
 from .worker import Task

+# Queue is generic only in stubs.
+# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
+if TYPE_CHECKING:
+    InQueue = Queue[Optional[Task]]
+    OutQueue = Queue[Tuple[bool, Union[Tuple[Task, Batch], ExcInfo, None]]]
+else:
+    InQueue = Queue
+    OutQueue = Queue
+
 __all__: List[str] = []

 ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
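The relocated `InQueue`/`OutQueue` aliases lean on `queue.Queue` being generic in typeshed stubs but not at runtime: on Python versions before 3.9, `Queue[...]` raises `TypeError` when actually executed, so the subscripted form must stay behind `TYPE_CHECKING`. A self-contained sketch of the pattern:

from queue import Queue
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by the type checker, so subscripting the stub is safe.
    IntQueue = Queue[Optional[int]]
else:
    # At runtime on older Pythons, Queue[...] would raise TypeError.
    IntQueue = Queue

def make_queue() -> "IntQueue":
    return Queue()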
@@ -49,8 +58,10 @@ class SendOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
-        assert src_rank == torch.distributed.get_rank()
+    def forward(ctx, transport: Transport, input: List[Tensor], index: int) -> Tensors:
+        ranks = get_pipeline_parallel_ranks()
+        src_rank = torch.distributed.get_rank()
+        dst_rank = ranks[ranks.index(src_rank) + 1]
         transport.send_message(
             PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
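Both autograd operators now derive their peers internally: the destination (or source) of a message is simply the neighboring entry in the pipeline-parallel rank list, so callers no longer thread `src_rank`/`dst_rank` through `apply`. A toy illustration of the indexing (rank values assumed):

ranks = [0, 2, 4, 6]                          # hypothetical pipeline-parallel ranks
src_rank = 4                                  # rank of this process
dst_rank = ranks[ranks.index(src_rank) + 1]   # 6: the next stage downstream
prev_rank = ranks[ranks.index(src_rank) - 1]  # 2: the previous stage upstream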
@@ -68,8 +79,7 @@ class RecvOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, dst_rank: int, tensor: Tensor, transport: Transport, index: int) -> Tensors:
-        assert dst_rank == torch.distributed.get_rank()
+    def forward(ctx, tensor: Tensor, transport: Transport, index: int) -> Tensors:
         ctx.transport = transport
         ctx.index = index
@@ -86,74 +96,12 @@ class RecvOperator(torch.autograd.Function):
     # type: ignore
     def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
         ranks = get_pipeline_parallel_ranks()
-        this_rank = torch.distributed.get_rank()
+        src_rank = torch.distributed.get_rank()
+        dst_rank = ranks[ranks.index(src_rank) - 1]
         ctx.transport.send_message(
-            PipeMessage(
-                this_rank,
-                ranks[ranks.index(this_rank) - 1],
-                queue_name=ACTIVATIONS_GRADS_QUEUE,
-                args=ctx.index,
-                tensors=tuple(grad),
-            ),
+            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=ctx.index, tensors=tuple(grad),),
         )
-        return (None, None, None, None, None)
+        return (None, None, None, None)
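Dropping inputs from `forward` is also why the `backward` return tuple shrinks here: `torch.autograd.Function.backward` must return one gradient slot per `forward` argument, with `None` for non-differentiable inputs (extra trailing `None`s appear to be tolerated, which would explain the old five-slot return). A minimal sketch of that contract with a hypothetical function, not from this diff:

import torch
from torch import Tensor
from typing import Optional, Tuple

class Scale(torch.autograd.Function):
    """Toy Function: multiply a tensor by a plain-Python factor."""

    @staticmethod
    def forward(ctx, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # One slot per forward argument: a gradient for `x`,
        # None for the non-tensor `factor`.
        return grad * ctx.factor, None

y = Scale.apply(torch.ones(3, requires_grad=True), 2.0)
y.sum().backward()  # works: backward returned exactly two slots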
-# Queue is generic only in stubs.
-# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
-if TYPE_CHECKING:
-    InQueue = Queue[Optional["Task"]]
-    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
-else:
-    InQueue = Queue
-    OutQueue = Queue
-
-
-def create_task(
-    checkpoint_stop: int,
-    i: int,
-    j: int,
-    batch: Batch,
-    partition: nn.Sequential,
-    skip_trackers: List[SkipTrackerThroughPotals],
-) -> Task:
-    # Determine whether checkpointing or not.
-    if i < checkpoint_stop:
-
-        def function(
-            input: TensorOrTensors,
-            partition: nn.Sequential = partition,
-            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
-            chunk_id: int = i,
-            part_id: int = j,
-        ) -> TensorOrTensors:
-            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
-                ret = partition(input)
-                # We do a check here because the backtrace from the checkpoint backward code path
-                # is very hard to make sense. It would be much easier to check earlier at this point.
-                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
-                return ret
-
-        chk = Checkpointing(function, batch)
-        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
-        del function, chk  # TODO(tom) maybe remove
-    else:
-
-        def compute(
-            batch: Batch = batch,
-            partition: nn.Sequential = partition,
-            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
-            chunk_id: int = i,
-            part_id: int = j,
-        ) -> Batch:
-            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
-                return batch.call(partition)
-
-        task = Task(None, compute=compute, finalize=None)
-        del compute  # TODO(tom) maybe remove
-
-    return task
-
-
 class MultiProcessPipeline:
@@ -191,7 +139,7 @@ class MultiProcessPipeline:
             return 0
         return self.__checkpoint_stop

-    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
+    def run(self, training: bool, batches: List[Batch]) -> None:
         """Runs pipeline parallelism.
@@ -204,24 +152,39 @@ class MultiProcessPipeline:
         skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]

-        schedule = [(i, self.group.rank()) for i in range(m)]
+        rank = self.group.rank()

-        for i, j in schedule:
-            if self.group.rank() != 0:
+        for i in range(m):
+            if rank != 0:
                 batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
             else:
                 batch = batches[i]

-            task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)
-
-            batches[i] = self.execute_task(task, i, skip_trackers)
+            with use_skip_tracker(skip_trackers[i]), record_function("chunk%d-part%d" % (i, rank)):
+                if i < self.checkpoint_stop:
+                    chk = Checkpointing(self.partition, batch)
+                    batch = chk.checkpoint()
+                else:
+                    batch = batch.call(self.partition)
+
+            if not self.final_stage:
+                self.send_skip_tensors(batch, i, skip_trackers)
+                SendOperator.apply(self.transport, [*batch], i)
+
+            for portal in skip_trackers[i].portals.values():
+                portal.pipeline = self
+
+            if i < self.checkpoint_stop:
+                chk.recompute(batch)
+
+            batches[i] = batch
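The inlined loop keeps the two-phase checkpoint protocol that `create_task` used to hide behind a `Task`: `chk.checkpoint()` runs the partition while discarding intermediate activations, and `chk.recompute(batch)` replays the forward later so backward has what it needs. The same memory-for-compute trade, sketched with stock `torch.utils.checkpoint` rather than fairscale's `Checkpointing` class (toy partition, assumptions mine):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

partition = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Forward discards intermediate activations; they are recomputed
# on the fly during backward, trading extra compute for less memory.
y = checkpoint(partition, x)
y.sum().backward()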
     def get_batch_from_previous_stage(
         self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
     ) -> Batch:
         phony = torch.empty(0, device=self.input_device, requires_grad=True)
-        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.transport, i)
+        result = RecvOperator.apply(phony, self.transport, i)
         if len(result) == 1:
             batch = Batch(result[0], i)
         else:
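The `phony` tensor is what stitches the receive into autograd: an empty tensor with `requires_grad=True` gives `RecvOperator.apply` a differentiable input, so a node lands in the graph and its `backward` can ship gradients to the previous stage. A stripped-down sketch of the trick (hypothetical function, illustration only):

import torch
from torch import Tensor
from typing import Tuple

class FakeRecv(torch.autograd.Function):
    """Pretends to receive a tensor; the phony input hooks it into autograd."""

    @staticmethod
    def forward(ctx, phony: Tensor) -> Tensor:
        # Stand-in for a tensor received from the previous pipeline stage.
        return torch.ones(3)

    @staticmethod
    def backward(ctx, grad: Tensor) -> Tuple[None]:
        # Real code would send `grad` upstream here.
        return (None,)

phony = torch.empty(0, requires_grad=True)
out = FakeRecv.apply(phony)
print(out.requires_grad)  # True: the recv participates in backward
out.sum().backward()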
@@ -231,9 +194,10 @@ class MultiProcessPipeline:
         return batch

     def send_skip_tensors(
-        self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
+        self, batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
+        ranks = get_pipeline_parallel_ranks()
+        this_rank = torch.distributed.get_rank()
         for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
             life = skip_trackers[i].portals[(ns, name)].tensor_life
             loaded = skip_trackers[i].load(batch, ns, name)
@@ -271,25 +235,6 @@ class MultiProcessPipeline:
             except QueueEmpty:
                 break

-    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
-        batch = task.compute()
-
-        rank = self.group.rank()
-
-        if not self.final_stage:
-            ranks = get_pipeline_parallel_ranks()
-            this_rank = torch.distributed.get_rank()
-
-            self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
-            SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
-
-        for portal in skip_trackers[i].portals.values():
-            portal.pipeline = self
-
-        task.finalize(batch)
-
-        return batch
-
     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
         if dest == src: