Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
2eb1b8ec
Unverified
Commit
2eb1b8ec
authored
Jan 28, 2021
by
msbaines
Committed by
GitHub
Jan 28, 2021
Browse files
[cleanup] multiprocess_pipe: dead-code removal and simplification (#335)
parent
65ca68a9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
85 deletions
+44
-85
fairscale/nn/pipe/async_schedule.py
fairscale/nn/pipe/async_schedule.py
+5
-12
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+8
-16
fairscale/nn/pipe/multiprocess_pipeline.py
fairscale/nn/pipe/multiprocess_pipeline.py
+31
-57
No files found.
fairscale/nn/pipe/async_schedule.py
View file @
2eb1b8ec
...
@@ -18,7 +18,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
...
@@ -18,7 +18,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from
.messages
import
Transport
from
.messages
import
Transport
from
.microbatch
import
Batch
from
.microbatch
import
Batch
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.types
import
EVENT_LOOP_QUEUE
,
PipelineStyle
,
PipeMessage
,
Tensors
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
Tensors
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -190,17 +190,13 @@ class AsyncEventLoop:
...
@@ -190,17 +190,13 @@ class AsyncEventLoop:
)
->
Batch
:
)
->
Batch
:
"""Actually run the forward pass for a given module, and send the result
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
to the next stage in the pipeline if needed."""
assert
self
.
group
# We import here to avoid a cyclic dependency.
# TODO(msb) Break the cyclic dependency.
from
.multiprocess_pipeline
import
create_task
from
.multiprocess_pipeline
import
create_task
task
=
create_task
(
task
=
create_task
(
PipelineStyle
.
AsyncSchedule
,
self
.
checkpoint_stop
,
batch
.
index
,
self
.
group
.
rank
(),
batch
,
partition
.
module
,
skip_trackers
,
self
.
checkpoint_stop
,
batch
.
index
,
self
.
group
.
rank
(),
batch
,
partition
.
module
,
skip_trackers
,
)
)
result
=
task
.
compute
()
result
=
task
.
compute
()
task
.
finalize
(
result
)
task
.
finalize
(
result
)
...
@@ -316,8 +312,6 @@ class AsyncEventLoop:
...
@@ -316,8 +312,6 @@ class AsyncEventLoop:
calculated. This also handles the first/only stage for the special
calculated. This also handles the first/only stage for the special
case of a 1-stage pipeline."""
case of a 1-stage pipeline."""
assert
self
.
group
invocations
,
activations
=
self
.
get_invocations_and_activations
()
invocations
,
activations
=
self
.
get_invocations_and_activations
()
expected_invocations
=
len
(
invocations
)
*
len
(
batches
)
expected_invocations
=
len
(
invocations
)
*
len
(
batches
)
actual_invocations
=
0
actual_invocations
=
0
...
@@ -379,7 +373,6 @@ class AsyncEventLoop:
...
@@ -379,7 +373,6 @@ class AsyncEventLoop:
def
event_loop
(
self
,
num_microbatch
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
None
:
def
event_loop
(
self
,
num_microbatch
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
None
:
"""The event loop for the "middle", i.e. neither the head nor the tail"""
"""The event loop for the "middle", i.e. neither the head nor the tail"""
assert
self
.
group
invocations
,
activations
=
self
.
get_invocations_and_activations
()
invocations
,
activations
=
self
.
get_invocations_and_activations
()
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
2eb1b8ec
...
@@ -36,6 +36,7 @@ from . import microbatch
...
@@ -36,6 +36,7 @@ from . import microbatch
from
.async_schedule
import
Invocation
,
Location
,
ModuleWrapper
from
.async_schedule
import
Invocation
,
Location
,
ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.batchnorm
import
DeferredBatchNorm
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.phony
import
get_phony
from
.skip.layout
import
SkipLayout
,
inspect_skip_layout
from
.skip.layout
import
SkipLayout
,
inspect_skip_layout
from
.skip.skippable
import
Skippable
,
verify_skippables
from
.skip.skippable
import
Skippable
,
verify_skippables
from
.types
import
LazyModule
,
PipelineStyle
from
.types
import
LazyModule
,
PipelineStyle
...
@@ -43,9 +44,6 @@ from .types import LazyModule, PipelineStyle
...
@@ -43,9 +44,6 @@ from .types import LazyModule, PipelineStyle
__all__
=
[
"MultiProcessPipe"
,
"LazyModule"
]
__all__
=
[
"MultiProcessPipe"
,
"LazyModule"
]
Device
=
Union
[
torch
.
device
,
int
,
str
]
Devices
=
Union
[
Iterable
[
Device
],
List
[
Device
]]
Tensors
=
Tuple
[
Tensor
,
...]
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
...
@@ -579,10 +577,6 @@ class MultiProcessPipe(Module):
...
@@ -579,10 +577,6 @@ class MultiProcessPipe(Module):
"""
"""
microbatch
.
check
(
input
)
microbatch
.
check
(
input
)
if
not
self
.
group
:
# Empty sequential module is not illegal.
return
input
if
not
self
.
pipeline
:
if
not
self
.
pipeline
:
# No pipeline is not illegal, more ranks than partitions
# No pipeline is not illegal, more ranks than partitions
return
input
return
input
...
@@ -594,19 +588,12 @@ class MultiProcessPipe(Module):
...
@@ -594,19 +588,12 @@ class MultiProcessPipe(Module):
with
self
.
lock
:
with
self
.
lock
:
self
.
pipeline
.
run
(
self
.
training
,
batches
,
event
)
self
.
pipeline
.
run
(
self
.
training
,
batches
,
event
)
if
not
self
.
final_stage
:
if
self
.
final_stage
:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return
batches
# type: ignore
else
:
# Merge the micro-batches into one mini-batch.
# Merge the micro-batches into one mini-batch.
if
self
.
pipelined_backward
:
if
self
.
pipelined_backward
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
microbatch
.
gather
(
batches
)
output
=
microbatch
.
gather
(
batches
)
from
.phony
import
get_phony
phony
=
get_phony
(
phony
=
get_phony
(
torch
.
device
(
torch
.
cuda
.
current_device
()
if
torch
.
cuda
.
is_available
()
else
"cpu"
),
torch
.
device
(
torch
.
cuda
.
current_device
()
if
torch
.
cuda
.
is_available
()
else
"cpu"
),
requires_grad
=
True
,
requires_grad
=
True
,
...
@@ -614,6 +601,11 @@ class MultiProcessPipe(Module):
...
@@ -614,6 +601,11 @@ class MultiProcessPipe(Module):
output
=
PipelinedBackwardPass
.
apply
(
output
,
batches
,
phony
,
True
)
# self.retain_graph)
output
=
PipelinedBackwardPass
.
apply
(
output
,
batches
,
phony
,
True
)
# self.retain_graph)
else
:
else
:
output
=
microbatch
.
gather
(
batches
)
output
=
microbatch
.
gather
(
batches
)
else
:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
output
=
batches
# type: ignore
return
output
return
output
...
@@ -622,7 +614,7 @@ class MultiProcessPipe(Module):
...
@@ -622,7 +614,7 @@ class MultiProcessPipe(Module):
raise
ValueError
(
"back_helper should only be called on non-final stages"
)
raise
ValueError
(
"back_helper should only be called on non-final stages"
)
if
self
.
pipeline
:
if
self
.
pipeline
:
self
.
pipeline
.
back_helper
(
list
(
reversed
(
output
)
))
self
.
pipeline
.
back_helper
(
output
)
class
PipelinedBackwardPass
(
torch
.
autograd
.
Function
):
class
PipelinedBackwardPass
(
torch
.
autograd
.
Function
):
...
...
fairscale/nn/pipe/multiprocess_pipeline.py
View file @
2eb1b8ec
...
@@ -78,7 +78,7 @@ class RecvOperator(torch.autograd.Function):
...
@@ -78,7 +78,7 @@ class RecvOperator(torch.autograd.Function):
@
staticmethod
@
staticmethod
# type: ignore
# type: ignore
def
forward
(
ctx
,
dst_rank
:
int
,
tensor
:
Tensor
,
input_device
,
transport
:
Transport
,
index
:
int
)
->
Tensors
:
def
forward
(
ctx
,
dst_rank
:
int
,
tensor
:
Tensor
,
transport
:
Transport
,
index
:
int
)
->
Tensors
:
assert
dst_rank
==
torch
.
distributed
.
get_rank
()
assert
dst_rank
==
torch
.
distributed
.
get_rank
()
ctx
.
transport
=
transport
ctx
.
transport
=
transport
ctx
.
index
=
index
ctx
.
index
=
index
...
@@ -120,7 +120,6 @@ else:
...
@@ -120,7 +120,6 @@ else:
def
create_task
(
def
create_task
(
style
:
PipelineStyle
,
checkpoint_stop
:
int
,
checkpoint_stop
:
int
,
i
:
int
,
i
:
int
,
j
:
int
,
j
:
int
,
...
@@ -176,7 +175,7 @@ class MultiProcessPipeline:
...
@@ -176,7 +175,7 @@ class MultiProcessPipeline:
skip_layout
:
SkipLayout
,
skip_layout
:
SkipLayout
,
checkpoint_stop
:
int
,
checkpoint_stop
:
int
,
style
:
PipelineStyle
,
style
:
PipelineStyle
,
group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
group
:
torch
.
distributed
.
ProcessGroup
,
worker_map
:
Optional
[
Dict
[
int
,
str
]]
=
None
,
worker_map
:
Optional
[
Dict
[
int
,
str
]]
=
None
,
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
final_stage
:
bool
=
False
,
final_stage
:
bool
=
False
,
...
@@ -193,7 +192,6 @@ class MultiProcessPipeline:
...
@@ -193,7 +192,6 @@ class MultiProcessPipeline:
input_device
=
input_device
,
input_device
=
input_device
,
)
)
self
.
input_device
=
input_device
self
.
input_device
=
input_device
self
.
all_at_once
=
False
self
.
callcount
=
0
self
.
callcount
=
0
self
.
final_stage
=
final_stage
self
.
final_stage
=
final_stage
...
@@ -219,11 +217,9 @@ class MultiProcessPipeline:
...
@@ -219,11 +217,9 @@ class MultiProcessPipeline:
skip_trackers
=
[
SkipTrackerThroughPotals
(
self
.
skip_layout
,
i
)
for
i
in
range
(
len
(
batches
))]
skip_trackers
=
[
SkipTrackerThroughPotals
(
self
.
skip_layout
,
i
)
for
i
in
range
(
len
(
batches
))]
if
self
.
style
is
PipelineStyle
.
MultiProcess
:
if
self
.
style
is
PipelineStyle
.
MultiProcess
:
assert
self
.
group
schedule
=
[(
i
,
self
.
group
.
rank
())
for
i
in
range
(
m
)]
schedule
=
[(
i
,
self
.
group
.
rank
())
for
i
in
range
(
m
)]
self
.
compute
(
batches
,
schedule
,
skip_trackers
)
self
.
compute
(
batches
,
schedule
,
skip_trackers
)
elif
self
.
style
is
PipelineStyle
.
AsyncSchedule
:
elif
self
.
style
is
PipelineStyle
.
AsyncSchedule
:
assert
self
.
group
rank
=
self
.
group
.
rank
()
rank
=
self
.
group
.
rank
()
event_loop
=
AsyncEventLoop
(
event_loop
=
AsyncEventLoop
(
self
.
partitions
,
self
.
group
,
self
.
transport
,
self
.
training
,
self
.
checkpoint_stop
,
self
.
partitions
,
self
.
group
,
self
.
transport
,
self
.
training
,
self
.
checkpoint_stop
,
...
@@ -248,7 +244,7 @@ class MultiProcessPipeline:
...
@@ -248,7 +244,7 @@ class MultiProcessPipeline:
)
->
Batch
:
)
->
Batch
:
phony
=
torch
.
empty
(
0
,
device
=
self
.
input_device
,
requires_grad
=
True
)
phony
=
torch
.
empty
(
0
,
device
=
self
.
input_device
,
requires_grad
=
True
)
result
=
RecvOperator
.
apply
(
torch
.
distributed
.
get_rank
(),
phony
,
self
.
input_device
,
self
.
transport
,
i
)
result
=
RecvOperator
.
apply
(
torch
.
distributed
.
get_rank
(),
phony
,
self
.
transport
,
i
)
if
len
(
result
)
==
1
:
if
len
(
result
)
==
1
:
batch
=
Batch
(
result
[
0
],
i
)
batch
=
Batch
(
result
[
0
],
i
)
else
:
else
:
...
@@ -261,7 +257,6 @@ class MultiProcessPipeline:
...
@@ -261,7 +257,6 @@ class MultiProcessPipeline:
def
send_skip_tensors
(
def
send_skip_tensors
(
self
,
this_rank
:
int
,
ranks
:
List
[
int
],
batch
:
Batch
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
]
self
,
this_rank
:
int
,
ranks
:
List
[
int
],
batch
:
Batch
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
]
)
->
None
:
)
->
None
:
assert
self
.
group
for
next_j
,
ns
,
name
in
self
.
skip_layout
.
copy_policy_by_src
(
self
.
group
.
rank
()):
for
next_j
,
ns
,
name
in
self
.
skip_layout
.
copy_policy_by_src
(
self
.
group
.
rank
()):
life
=
skip_trackers
[
i
].
portals
[(
ns
,
name
)].
tensor_life
life
=
skip_trackers
[
i
].
portals
[(
ns
,
name
)].
tensor_life
loaded
=
skip_trackers
[
i
].
load
(
batch
,
ns
,
name
)
loaded
=
skip_trackers
[
i
].
load
(
batch
,
ns
,
name
)
...
@@ -302,7 +297,6 @@ class MultiProcessPipeline:
...
@@ -302,7 +297,6 @@ class MultiProcessPipeline:
def
execute_task
(
self
,
task
:
Task
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
Batch
:
def
execute_task
(
self
,
task
:
Task
,
i
:
int
,
skip_trackers
:
List
[
SkipTrackerThroughPotals
])
->
Batch
:
batch
=
task
.
compute
()
batch
=
task
.
compute
()
assert
self
.
group
rank
=
self
.
group
.
rank
()
rank
=
self
.
group
.
rank
()
if
self
.
style
is
PipelineStyle
.
MultiProcess
and
not
self
.
final_stage
:
if
self
.
style
is
PipelineStyle
.
MultiProcess
and
not
self
.
final_stage
:
...
@@ -324,9 +318,7 @@ class MultiProcessPipeline:
...
@@ -324,9 +318,7 @@ class MultiProcessPipeline:
)
->
None
:
)
->
None
:
"""Runs tasks with synchronization to copy streams."""
"""Runs tasks with synchronization to copy streams."""
if
self
.
style
is
PipelineStyle
.
MultiProcess
:
assert
self
.
style
is
PipelineStyle
.
MultiProcess
assert
self
.
group
n
=
self
.
group
.
size
()
# With checkpointing, the autograd graph looks like this diagram:
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# ┌─────┸──────┐
...
@@ -354,17 +346,15 @@ class MultiProcessPipeline:
...
@@ -354,17 +346,15 @@ class MultiProcessPipeline:
# │ Copy │
# │ Copy │
# └─────┰──────┘
# └─────┰──────┘
for
i
,
j
in
schedule
:
for
i
,
j
in
schedule
:
batch
=
batches
[
i
]
if
self
.
style
is
PipelineStyle
.
MultiProcess
:
assert
len
(
self
.
partitions
)
==
1
assert
len
(
self
.
partitions
)
==
1
partition
=
self
.
partitions
[
0
]
partition
=
self
.
partitions
[
0
]
assert
self
.
group
if
self
.
group
.
rank
()
!=
0
:
if
self
.
group
.
rank
()
!=
0
:
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
else
:
batch
=
batches
[
i
]
task
=
create_task
(
self
.
style
,
self
.
checkpoint_stop
,
i
,
j
,
batch
,
partition
.
module
,
skip_trackers
)
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
partition
.
module
,
skip_trackers
)
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
...
@@ -398,26 +388,10 @@ class MultiProcessPipeline:
...
@@ -398,26 +388,10 @@ class MultiProcessPipeline:
if
self
.
style
==
PipelineStyle
.
AsyncSchedule
:
if
self
.
style
==
PipelineStyle
.
AsyncSchedule
:
return
return
o
=
list
(
output
)
tensors
:
Tensors
tensors
:
Tensors
if
self
.
all_at_once
:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads
=
[]
for
i
,
batch
in
enumerate
(
o
):
rank
=
torch
.
distributed
.
get_rank
()
found
=
self
.
transport
.
get_out_of_order
(
ACTIVATIONS_GRADS_QUEUE
,
i
)
assert
len
(
found
)
==
1
grads
.
append
(
found
[
0
])
tensors
=
tuple
(
x
.
tensor_or_tensors
for
x
in
o
)
# type: ignore
try
:
torch
.
autograd
.
backward
(
tensors
,
grad_tensors
=
grads
,
retain_graph
=
True
)
except
Exception
as
e
:
raise
RuntimeError
(
"Autograd failed"
)
from
e
else
:
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
for
batch
in
o
:
for
batch
in
reversed
(
output
)
:
found
=
self
.
transport
.
get_out_of_order
(
ACTIVATIONS_GRADS_QUEUE
,
batch
.
index
)
found
=
self
.
transport
.
get_out_of_order
(
ACTIVATIONS_GRADS_QUEUE
,
batch
.
index
)
if
batch
.
atomic
:
if
batch
.
atomic
:
tensors
=
tuple
([
batch
.
tensor
])
tensors
=
tuple
([
batch
.
tensor
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment