OpenDAS / fairscale

Commit 2eb1b8ec (Unverified), authored Jan 28, 2021 by msbaines, committed by GitHub Jan 28, 2021

[cleanup] multiprocess_pipe: dead-code removal and simplification (#335)

parent 65ca68a9
Showing 3 changed files with 44 additions and 85 deletions (+44 / -85):

fairscale/nn/pipe/async_schedule.py          +5  -12
fairscale/nn/pipe/multiprocess_pipe.py       +8  -16
fairscale/nn/pipe/multiprocess_pipeline.py  +31  -57
fairscale/nn/pipe/async_schedule.py
@@ -18,7 +18,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 from .messages import Transport
 from .microbatch import Batch
 from .skip.tracker import SkipTrackerThroughPotals
-from .types import EVENT_LOOP_QUEUE, PipelineStyle, PipeMessage, Tensors
+from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors


 @dataclass(frozen=True)
@@ -190,17 +190,13 @@ class AsyncEventLoop:
     ) -> Batch:
         """Actually run the forward pass for a given module, and send the result
         to the next stage in the pipeline if needed."""
         assert self.group

-        # We import here to avoid a cyclic dependency.
-        # TODO(msb) Break the cyclic dependency.
-        from .multiprocess_pipeline import create_task
-
         task = create_task(
-            PipelineStyle.AsyncSchedule,
-            self.checkpoint_stop,
-            batch.index,
-            self.group.rank(),
-            batch,
-            partition.module,
-            skip_trackers,
+            self.checkpoint_stop,
+            batch.index,
+            self.group.rank(),
+            batch,
+            partition.module,
+            skip_trackers,
         )
         result = task.compute()
         task.finalize(result)
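Taken together with the create_task hunk in multiprocess_pipeline.py further down, the change above drops the leading PipelineStyle argument from create_task, which is what lets the cycle-avoiding local import disappear. A minimal, self-contained sketch of the before/after call shape, using a hypothetical stand-in rather than fairscale's real Task and create_task:

from dataclasses import dataclass
from typing import Any, List


@dataclass
class Task:
    """Hypothetical stand-in for the pipeline Task built by create_task."""
    checkpoint_stop: int
    i: int                     # micro-batch index
    j: int                     # stage / rank
    batch: Any
    module: Any
    skip_trackers: List[Any]


def create_task(checkpoint_stop: int, i: int, j: int, batch: Any, module: Any, skip_trackers: List[Any]) -> Task:
    # After this commit the caller no longer passes a PipelineStyle; the same
    # Task is built for both schedules, so the style argument was dead weight.
    return Task(checkpoint_stop, i, j, batch, module, skip_trackers)


# Old call shape (removed): create_task(PipelineStyle.AsyncSchedule, stop, i, j, batch, module, trackers)
# New call shape (kept):    create_task(stop, i, j, batch, module, trackers)
task = create_task(0, 0, 0, batch=None, module=None, skip_trackers=[])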
@@ -316,8 +312,6 @@ class AsyncEventLoop:
        calculated. This also handles the first/only stage for the special
        case of a 1-stage pipeline."""
        assert self.group

        invocations, activations = self.get_invocations_and_activations()
        expected_invocations = len(invocations) * len(batches)
        actual_invocations = 0
@@ -379,7 +373,6 @@ class AsyncEventLoop:
    def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
        """The event loop for the "middle", i.e. neither the head nor the tail"""
        assert self.group

        invocations, activations = self.get_invocations_and_activations()
fairscale/nn/pipe/multiprocess_pipe.py
@@ -36,6 +36,7 @@ from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .types import LazyModule, PipelineStyle
@@ -43,9 +44,6 @@ from .types import LazyModule, PipelineStyle
__all__ = ["MultiProcessPipe", "LazyModule"]

Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
@@ -579,10 +577,6 @@ class MultiProcessPipe(Module):
        """
        microbatch.check(input)

        if not self.group:
            # Empty sequential module is not illegal.
            return input

        if not self.pipeline:
            # No pipeline is not illegal, more ranks than partitions
            return input
@@ -594,19 +588,12 @@ class MultiProcessPipe(Module):
         with self.lock:
             self.pipeline.run(self.training, batches, event)

-            if not self.final_stage:
-                # Don't merge micro-batches to avoid unnecessary edges in autograd
-                # graph
-                # FIXME(tom) should figure out a proper type here
-                return batches  # type: ignore
-            else:
+            if self.final_stage:
                 # Merge the micro-batches into one mini-batch.
                 if self.pipelined_backward:
                     with torch.no_grad():
                         output = microbatch.gather(batches)

-                    from .phony import get_phony
-
                     phony = get_phony(
                         torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
                         requires_grad=True,
@@ -614,6 +601,11 @@ class MultiProcessPipe(Module):
                     output = PipelinedBackwardPass.apply(output, batches, phony, True)  # self.retain_graph)
                 else:
                     output = microbatch.gather(batches)
+            else:
+                # Don't merge micro-batches to avoid unnecessary edges in autograd
+                # graph
+                # FIXME(tom) should figure out a proper type here
+                output = batches  # type: ignore

         return output
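The two forward() hunks above replace the early `return batches` for non-final stages with a single `if self.final_stage: ... else: ...` block that ends in one `return output`. A rough, standalone sketch of the resulting control flow; `gather` and `forward_output` below are placeholders, not fairscale's API:

from typing import List

import torch


def gather(batches: List[torch.Tensor]) -> torch.Tensor:
    # Stand-in for microbatch.gather(): merge micro-batches into one mini-batch.
    return torch.cat(batches, dim=0)


def forward_output(batches: List[torch.Tensor], final_stage: bool):
    """Rough mirror of the simplified tail of MultiProcessPipe.forward()."""
    if final_stage:
        # Final stage: merge the micro-batches into one mini-batch (the
        # pipelined-backward branch of the real code is omitted here).
        output = gather(batches)
    else:
        # Non-final stages keep micro-batches separate to avoid unnecessary
        # edges in the autograd graph.
        output = batches
    return output


print(forward_output([torch.ones(2, 3), torch.ones(2, 3)], final_stage=True).shape)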
@@ -622,7 +614,7 @@ class MultiProcessPipe(Module):
             raise ValueError("back_helper should only be called on non-final stages")

         if self.pipeline:
-            self.pipeline.back_helper(list(reversed(output)))
+            self.pipeline.back_helper(output)


 class PipelinedBackwardPass(torch.autograd.Function):
fairscale/nn/pipe/multiprocess_pipeline.py
@@ -78,7 +78,7 @@ class RecvOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
+    def forward(ctx, dst_rank: int, tensor: Tensor, transport: Transport, index: int) -> Tensors:
         assert dst_rank == torch.distributed.get_rank()
         ctx.transport = transport
         ctx.index = index
@@ -120,7 +120,6 @@ else:
 def create_task(
-    style: PipelineStyle,
     checkpoint_stop: int,
     i: int,
     j: int,
@@ -176,7 +175,7 @@ class MultiProcessPipeline:
         skip_layout: SkipLayout,
         checkpoint_stop: int,
         style: PipelineStyle,
-        group: Optional[torch.distributed.ProcessGroup] = None,
+        group: torch.distributed.ProcessGroup,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
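The constructor hunk above tightens `group` from `Optional[torch.distributed.ProcessGroup] = None` to a required `ProcessGroup`. The `assert self.group` guards that appear throughout the rest of the diff look like the usual mypy idiom for narrowing an Optional attribute before calling `.rank()` or `.size()`; a small generic illustration of that pattern, unrelated to fairscale's real classes:

from typing import Optional


class FakeGroup:
    """Tiny stand-in for torch.distributed.ProcessGroup, used only in this sketch."""

    def rank(self) -> int:
        return 0


class Stage:
    def __init__(self, group: Optional[FakeGroup] = None) -> None:
        self.group = group

    def rank(self) -> int:
        # Without the assert, a type checker flags `self.group.rank()` because
        # `self.group` may still be None at this point.
        assert self.group
        return self.group.rank()


print(Stage(FakeGroup()).rank())  # -> 0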
@@ -193,7 +192,6 @@ class MultiProcessPipeline:
            input_device=input_device,
        )

        self.input_device = input_device
        self.all_at_once = False
        self.callcount = 0
        self.final_stage = final_stage
@@ -219,11 +217,9 @@ class MultiProcessPipeline:
        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

        if self.style is PipelineStyle.MultiProcess:
            assert self.group
            schedule = [(i, self.group.rank()) for i in range(m)]
            self.compute(batches, schedule, skip_trackers)
        elif self.style is PipelineStyle.AsyncSchedule:
            assert self.group
            rank = self.group.rank()
            event_loop = AsyncEventLoop(
                self.partitions,
                self.group,
                self.transport,
                self.training,
                self.checkpoint_stop,
@@ -248,7 +244,7 @@ class MultiProcessPipeline:
     ) -> Batch:
         phony = torch.empty(0, device=self.input_device, requires_grad=True)

-        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
+        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.transport, i)
         if len(result) == 1:
             batch = Batch(result[0], i)
         else:
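The hunk above also shows the pattern this pipeline uses to pull a received activation into the autograd graph: a zero-element `phony` tensor with `requires_grad=True` is fed through a custom `autograd.Function` (`RecvOperator`), and only the unused `input_device` argument is dropped by this commit. A minimal, self-contained sketch of that phony-tensor technique; `FakeRecv` is an illustration, not fairscale's RecvOperator:

import torch


class FakeRecv(torch.autograd.Function):
    """Toy version of the 'receive through a phony tensor' trick."""

    @staticmethod
    def forward(ctx, phony: torch.Tensor, payload: torch.Tensor) -> torch.Tensor:
        # In the real pipeline the payload would arrive over the transport;
        # here we simply hand back a copy of it.
        return payload.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Gradients flowing back through the phony tensor are where the real
        # implementation would ship grads to the previous stage.
        return grad_output.new_zeros(0), None


phony = torch.empty(0, requires_grad=True)            # same zero-element trick as in the diff
received = FakeRecv.apply(phony, torch.randn(2, 2))   # output is attached to the autograd graph
received.sum().backward()
print(phony.grad.shape)  # torch.Size([0])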
@@ -261,7 +257,6 @@ class MultiProcessPipeline:
    def send_skip_tensors(
        self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
    ) -> None:
        assert self.group
        for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
            life = skip_trackers[i].portals[(ns, name)].tensor_life
            loaded = skip_trackers[i].load(batch, ns, name)
@@ -302,7 +297,6 @@ class MultiProcessPipeline:
    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
        batch = task.compute()

        assert self.group
        rank = self.group.rank()

        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
@@ -324,9 +318,7 @@ class MultiProcessPipeline:
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        if self.style is PipelineStyle.MultiProcess:
            assert self.group
            n = self.group.size()
        assert self.style is PipelineStyle.MultiProcess

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
@@ -354,19 +346,17 @@ class MultiProcessPipeline:
         #                 │ Copy │
         #                 └─────┰──────┘
         for i, j in schedule:
-            batch = batches[i]
-
-            if self.style is PipelineStyle.MultiProcess:
-                assert len(self.partitions) == 1
-                partition = self.partitions[0]
+            assert len(self.partitions) == 1
+            partition = self.partitions[0]

             assert self.group
-            if self.group.rank() != 0:
-                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            if self.group.rank() != 0:
+                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            else:
+                batch = batches[i]

-            task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)

-            batches[i] = self.execute_task(task, i, skip_trackers)
+            batches[i] = self.execute_task(task, i, skip_trackers)

     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
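The compute() hunk above is the main simplification in this file: the per-style branching inside the schedule loop goes away, and each `(i, j)` entry fetches its batch from the previous stage when `group.rank() != 0`, builds a task with the shortened create_task signature, and writes the result back into `batches`. A simplified, self-contained sketch of that loop with hypothetical stand-ins (no real transport, Task, or skip trackers):

from typing import Callable, List, Tuple

# Hypothetical stand-ins: a "batch" is a list of floats and a "module" a callable.
Batch = List[float]


def compute_schedule(
    batches: List[Batch],
    schedule: List[Tuple[int, int]],              # (micro-batch index, stage rank)
    rank: int,
    module: Callable[[Batch], Batch],
    recv_from_previous_stage: Callable[[int], Batch],
) -> None:
    for i, j in schedule:
        if rank != 0:
            # Non-first stages receive their input micro-batch from the previous stage.
            batch = recv_from_previous_stage(i)
        else:
            batch = batches[i]
        # The diff calls create_task(checkpoint_stop, i, j, batch, partition.module,
        # skip_trackers) and then execute_task(); here we just run the module directly.
        batches[i] = module(batch)


batches = [[1.0, 2.0], [3.0, 4.0]]
compute_schedule(batches, [(0, 0), (1, 0)], rank=0,
                 module=lambda b: [x * 2 for x in b],
                 recv_from_previous_stage=lambda i: [])
print(batches)  # [[2.0, 4.0], [6.0, 8.0]]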
@@ -398,43 +388,27 @@ class MultiProcessPipeline:
        if self.style == PipelineStyle.AsyncSchedule:
            return

        o = list(output)
        tensors: Tensors

        if self.all_at_once:
            # FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
            rank = torch.distributed.get_rank()
            for batch in reversed(output):
                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
                if batch.atomic:
                    tensors = tuple([batch.tensor])
                else:
                    tensors = batch.tensors

                if len(found) != len(tensors):
                    raise RuntimeError("different number of tensors and gradients")

                grads = []
        for i, batch in enumerate(o):
            rank = torch.distributed.get_rank()
            found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
            assert len(found) == 1
            grads.append(found[0])
        tensors = tuple(x.tensor_or_tensors for x in o)  # type: ignore
        final_tensors = []
        for i, tensor in enumerate(tensors):
            if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
                grads.append(found[i])
                final_tensors.append(tensor)

        try:
            torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
            torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
        except Exception as e:
            raise RuntimeError("Autograd failed") from e
        else:
            rank = torch.distributed.get_rank()
            for batch in o:
                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
                if batch.atomic:
                    tensors = tuple([batch.tensor])
                else:
                    tensors = batch.tensors

                if len(found) != len(tensors):
                    raise RuntimeError("different number of tensors and gradients")

                grads = []
                final_tensors = []
                for i, tensor in enumerate(tensors):
                    if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
                        grads.append(found[i])
                        final_tensors.append(tensor)

                try:
                    torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
                except Exception as e:
                    raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
                    raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
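The final hunk rewrites back_helper, which drives backward on non-final stages: for each output micro-batch it pulls the gradients that arrived out of order on ACTIVATIONS_GRADS_QUEUE, keeps only the tensors that participate in autograd, and calls torch.autograd.backward with the matching grad_tensors. A compact, standalone sketch of that per-micro-batch pattern; the toy inputs below stand in for fairscale's transport:

from typing import Sequence

import torch


def back_helper_one_batch(tensors: Sequence[torch.Tensor], found: Sequence[torch.Tensor]) -> None:
    """Mirror of the surviving per-micro-batch logic in back_helper()."""
    if len(found) != len(tensors):
        raise RuntimeError("different number of tensors and gradients")

    grads = []
    final_tensors = []
    for i, tensor in enumerate(tensors):
        # Only tensors that are part of the autograd graph receive a gradient.
        if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
            grads.append(found[i])
            final_tensors.append(tensor)

    try:
        torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
    except Exception as e:
        raise RuntimeError("Autograd failed") from e


x = torch.ones(3, requires_grad=True)
y = x * 2
back_helper_one_batch([y, torch.zeros(3)], [torch.ones(3), torch.ones(3)])
print(x.grad)  # tensor([2., 2., 2.])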