Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b1b9e0f8
Unverified
Commit
b1b9e0f8
authored
Feb 08, 2021
by
msbaines
Committed by
GitHub
Feb 08, 2021
Browse files
[refactor] remove multiprocess dependency on async (#373)
parent
08c10993
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
24 deletions
+15
-24
fairscale/nn/pipe/async_pipe.py
fairscale/nn/pipe/async_pipe.py
+4
-0
fairscale/nn/pipe/async_schedule.py
fairscale/nn/pipe/async_schedule.py
+1
-4
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+1
-5
fairscale/nn/pipe/multiprocess_pipeline.py
fairscale/nn/pipe/multiprocess_pipeline.py
+4
-8
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+3
-5
tests/nn/pipe_process/test_transparency.py
tests/nn/pipe_process/test_transparency.py
+2
-2
No files found.
fairscale/nn/pipe/async_pipe.py
View file @
b1b9e0f8
...
...
@@ -192,6 +192,8 @@ class AsyncPipe(Module):
warnings
.
warn
(
"More ranks than partitions, some ranks unused"
)
self
.
partitions
:
List
[
ModuleWrapper
]
=
[]
self
.
pipeline
=
None
# TODO(msb) remove this hack
self
.
partition
=
None
else
:
self
.
partitions
=
self
.
instantiate_partition
(
module
,
self
.
balance
,
self
.
group
)
if
deferred_batch_norm
:
...
...
@@ -200,6 +202,8 @@ class AsyncPipe(Module):
for
name
,
part
in
enumerate
(
self
.
partitions
):
self
.
add_module
(
str
(
name
),
part
.
module
)
self
.
create_pipeline
()
# TODO(msb) remove this hack
self
.
partition
=
self
.
partitions
[
0
].
module
del
module
...
...
fairscale/nn/pipe/async_schedule.py
View file @
b1b9e0f8
...
...
@@ -17,6 +17,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from
.messages
import
Transport
from
.microbatch
import
Batch
from
.multiprocess_pipeline
import
create_task
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
Tensors
...
...
@@ -191,10 +192,6 @@ class AsyncEventLoop:
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
# We import here to avoid a cyclic dependency.
# TODO(msb) Break the cyclic dependency.
from
.multiprocess_pipeline
import
create_task
task
=
create_task
(
self
.
checkpoint_stop
,
batch
.
index
,
self
.
group
.
rank
(),
batch
,
partition
.
module
,
skip_trackers
,
)
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
b1b9e0f8
...
...
@@ -31,7 +31,6 @@ import torch.cuda
from
fairscale.nn.model_parallel
import
get_model_parallel_world_size
,
get_pipeline_parallel_group
from
.
import
microbatch
from
.async_schedule
import
Location
,
ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.phony
import
get_phony
...
...
@@ -219,9 +218,6 @@ class MultiProcessPipe(Module):
self
.
add_module
(
str
(
0
),
self
.
partition
)
self
.
create_pipeline
()
# TODO(msb) Remove this hack at some point.
self
.
partitions
=
[
ModuleWrapper
(
self
.
partition
,
Location
(
self
.
group
.
rank
(),
0
))]
del
module
def
create_pipeline
(
self
)
->
None
:
...
...
@@ -229,7 +225,7 @@ class MultiProcessPipe(Module):
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
self
.
pipeline
=
MultiProcessPipeline
(
[
ModuleWrapper
(
self
.
partition
,
Location
(
self
.
group
.
rank
(),
0
))]
,
self
.
partition
,
self
.
_skip_layout
,
checkpoint_stop
,
group
=
self
.
group
,
...
...
fairscale/nn/pipe/multiprocess_pipeline.py
View file @
b1b9e0f8
...
...
@@ -30,7 +30,6 @@ from torch.autograd.profiler import record_function
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_ranks
from
.async_schedule
import
ModuleWrapper
from
.checkpoint
import
Checkpointing
from
.messages
import
MakeTransport
,
Transport
from
.microbatch
import
Batch
...
...
@@ -162,7 +161,7 @@ class MultiProcessPipeline:
def
__init__
(
self
,
partition
s
:
List
[
ModuleWrapper
]
,
partition
:
nn
.
Sequential
,
skip_layout
:
SkipLayout
,
checkpoint_stop
:
int
,
group
:
torch
.
distributed
.
ProcessGroup
,
...
...
@@ -171,7 +170,7 @@ class MultiProcessPipeline:
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
final_stage
:
bool
=
False
,
)
->
None
:
self
.
partition
s
=
partition
s
self
.
partition
=
partition
self
.
skip_layout
=
skip_layout
self
.
__checkpoint_stop
=
checkpoint_stop
self
.
group
=
group
...
...
@@ -187,7 +186,7 @@ class MultiProcessPipeline:
@
property
def
checkpoint_stop
(
self
)
->
int
:
# Disable checkpointing if in eval mode.
training
=
self
.
partition
s
[
0
].
module
.
training
training
=
self
.
partition
.
training
if
not
training
:
return
0
return
self
.
__checkpoint_stop
...
...
@@ -208,15 +207,12 @@ class MultiProcessPipeline:
schedule
=
[(
i
,
self
.
group
.
rank
())
for
i
in
range
(
m
)]
for
i
,
j
in
schedule
:
assert
len
(
self
.
partitions
)
==
1
partition
=
self
.
partitions
[
0
]
if
self
.
group
.
rank
()
!=
0
:
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
else
:
batch
=
batches
[
i
]
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
partition
.
module
,
skip_trackers
)
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
self
.
partition
,
skip_trackers
)
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
...
...
tests/nn/pipe_process/test_pipe.py
View file @
b1b9e0f8
...
...
@@ -366,8 +366,8 @@ def no_grad(pipe_class):
nonlocal
latent
latent
=
output
partition
=
model
.
partition
s
[
0
]
partition
.
module
.
register_forward_hook
(
hook
)
partition
=
model
.
partition
partition
.
register_forward_hook
(
hook
)
with
torch
.
no_grad
():
model
(
input
)
...
...
@@ -616,9 +616,7 @@ def partitions(pipe_class):
model
=
nn
.
Sequential
(
a
,
b
)
model
=
pipe_class
(
model
,
[
1
,
1
],
worker_map
=
get_worker_map
())
assert
isinstance
(
model
.
partitions
,
list
)
assert
len
(
model
)
==
1
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
assert
isinstance
(
model
.
partition
,
nn
.
Sequential
)
if
model
.
group
.
rank
()
==
0
:
assert
model
[
0
].
weight
==
a
.
weight
...
...
tests/nn/pipe_process/test_transparency.py
View file @
b1b9e0f8
...
...
@@ -60,13 +60,13 @@ def simple_linears(pipe_class):
if
model
.
group
.
rank
()
==
1
:
loss
=
outputs
.
mean
()
loss
.
backward
()
grad_with_pipe
=
sum_grad
(
model
.
p
ipeline
.
partitions
[
0
].
module
.
parameters
())
grad_with_pipe
=
sum_grad
(
model
.
p
artition
.
parameters
())
# Both grads should be identical.
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
1
])
else
:
model
.
back_helper
(
outputs
)
grad_with_pipe
=
sum_grad
(
model
.
p
ipeline
.
partitions
[
0
].
module
.
parameters
())
grad_with_pipe
=
sum_grad
(
model
.
p
artition
.
parameters
())
# Both grads should be identical.
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment